trustgraph/trustgraph/graph_rag.py
2024-07-23 23:02:09 +01:00

239 lines
5.3 KiB
Python

from trustgraph.trustgraph import TrustGraph
from trustgraph.triple_vectors import TripleVectors
from trustgraph.trustgraph import TrustGraph
from trustgraph.llm_client import LlmClient
from trustgraph.embeddings_client import EmbeddingsClient
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class GraphRag:
def __init__(
self,
graph_hosts=None,
pulsar_host="pulsar://pulsar:6650",
vector_store="http://milvus:19530",
verbose=False,
entity_limit=50,
triple_limit=30,
max_subgraph_size=3000,
module="test",
):
self.verbose=verbose
if graph_hosts == None:
graph_hosts = ["cassandra"]
if self.verbose:
print("Initialising...", flush=True)
self.graph = TrustGraph(graph_hosts)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-emb",
)
self.vecstore = TripleVectors(vector_store)
self.entity_limit=entity_limit
self.query_limit=triple_limit
self.max_subgraph_size=max_subgraph_size
self.label_cache = {}
self.llm = LlmClient(
pulsar_host=pulsar_host,
subscriber=module + "-llm",
)
if self.verbose:
print("Initialised", flush=True)
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_entities(self, query):
everything = []
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
for vector in vectors:
res = self.vecstore.search(
vector,
limit=self.entity_limit
)
print("Obtained", len(res), "entities")
entities = set([
item["entity"]["entity"]
for item in res
])
everything.extend(entities)
if self.verbose:
print("Entities:", flush=True)
for ent in everything:
print(" ", ent, flush=True)
return everything
def maybe_label(self, e):
if e in self.label_cache:
return self.label_cache[e]
res = self.graph.get_sp(e, LABEL)
res = list(res)
if len(res) == 0:
self.label_cache[e] = e
return e
self.label_cache[e] = res[0][0]
return self.label_cache[e]
def get_nodes(self, query):
ents = self.get_entities(query)
if self.verbose:
print("Get labels...", flush=True)
nodes = [
self.maybe_label(e)
for e in ents
]
if self.verbose:
print("Nodes:", flush=True)
for node in nodes:
print(" ", node, flush=True)
return nodes
def get_subgraph(self, query):
entities = self.get_entities(query)
subgraph = set()
if self.verbose:
print("Get subgraph...", flush=True)
for e in entities:
res = self.graph.get_s(e, limit=self.query_limit)
for p, o in res:
subgraph.add((e, p, o))
res = self.graph.get_p(e, limit=self.query_limit)
for s, o in res:
subgraph.add((s, e, o))
res = self.graph.get_o(e, limit=self.query_limit)
for s, p in res:
subgraph.add((s, p, e))
subgraph = list(subgraph)
subgraph = subgraph[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in subgraph:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return subgraph
def get_labelgraph(self, query):
subgraph = self.get_subgraph(query)
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = self.maybe_label(edge[0])
p = self.maybe_label(edge[1])
o = self.maybe_label(edge[2])
sg2.append((s, p, o))
return sg2
def get_cypher(self, query):
sg = self.get_labelgraph(query)
sg2 = []
for s, p, o in sg:
sg2.append(f"({s})-[{p}]->({o})")
kg = "\n".join(sg2)
kg = kg.replace("\\", "-")
return kg
def get_graph_prompt(self, query):
kg = self.get_cypher(query)
prompt=f"""<instructions>Study the knowledge graph provided, and use
the information to answer the question. The question should be answered
in plain English only.
</instructions>
<knowledge-graph>
{kg}
</knowledge-graph>
<question>
{query}
</question>
"""
return prompt
def query(self, query):
if self.verbose:
print("Construct prompt...", flush=True)
prompt = self.get_graph_prompt(query)
if self.verbose:
print("Invoke LLM...", flush=True)
resp = self.llm.request(prompt)
if self.verbose:
print("Done", flush=True)
return resp