Break out store queries (#8)

- Break out store queries, so not locked into a Milvus/Cassandra backend
- Break out prompting into a separate module, so that prompts can be tailored to other LLMs
- Jsonnet used to generate docker compose templates
- Version to 0.6.0
This commit is contained in:
cybermaggedon 2024-08-13 17:30:59 +01:00 committed by GitHub
parent a9a0e28f49
commit a3ea1301d6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
70 changed files with 4286 additions and 2394 deletions

View file

View file

@ -67,7 +67,7 @@ class TrustGraph:
def get_s(self, s, limit=10):
return self.session.execute(
f"select p, o from triples where s = %s",
f"select p, o from triples where s = %s limit {limit}",
(s,)
)
@ -97,7 +97,7 @@ class TrustGraph:
def get_os(self, o, s, limit=10):
return self.session.execute(
f"select s from triples where o = %s and s = %s limit {limit}",
f"select p from triples where o = %s and s = %s limit {limit}",
(o, s)
)

View file

@ -9,10 +9,8 @@ import os
import argparse
import time
from .... trustgraph import TrustGraph
from .... schema import GraphEmbeddings
from .... schema import graph_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
from . writer import ParquetWriter

View file

@ -9,10 +9,8 @@ import os
import argparse
import time
from .... trustgraph import TrustGraph
from .... schema import Triple
from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
from . writer import ParquetWriter

View file

@ -0,0 +1,89 @@
#!/usr/bin/env python3
import pulsar
import _pulsar
from pulsar.schema import JsonSchema
import hashlib
import uuid
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from . schema import graph_embeddings_request_queue
from . schema import graph_embeddings_response_queue
# Ugly: the log-level values are read from the underscore-prefixed
# `_pulsar` module and re-exposed here under short names.
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class GraphEmbeddingsClient:
    """Request/response client for the graph embeddings query queues.

    A request is published on `input_queue` tagged with a freshly generated
    `id` message property; the client then reads `output_queue` until it
    sees a response carrying the same `id`.
    """

    def __init__(
            self, log_level=ERROR,
            subscriber=None,
            input_queue=None,
            output_queue=None,
            pulsar_host="pulsar://pulsar:6650",
    ):
        """Connect to Pulsar and create the producer and consumer.

        Args:
            log_level: level handed to `pulsar.ConsoleLogger`.
            subscriber: subscription name; a random UUID when None.
            input_queue: topic requests are sent to; defaults to
                `graph_embeddings_request_queue`.
            output_queue: topic responses are read from; defaults to
                `graph_embeddings_response_queue`.
            pulsar_host: Pulsar service URL.
        """

        if input_queue is None:
            input_queue = graph_embeddings_request_queue

        if output_queue is None:
            output_queue = graph_embeddings_response_queue

        if subscriber is None:
            subscriber = str(uuid.uuid4())

        self.client = pulsar.Client(
            pulsar_host,
            logger=pulsar.ConsoleLogger(log_level),
        )

        self.producer = self.client.create_producer(
            topic=input_queue,
            schema=JsonSchema(GraphEmbeddingsRequest),
            chunking_enabled=True,
        )

        self.consumer = self.client.subscribe(
            output_queue, subscriber,
            schema=JsonSchema(GraphEmbeddingsResponse),
        )

    def request(self, vectors, limit=10, timeout=500):
        """Send a query and return the `entities` of the matching response.

        Args:
            vectors: value for the request's `vectors` field.
            limit: value for the request's `limit` field.
            timeout: seconds to wait for each message (converted to
                milliseconds for `consumer.receive`).

        Whatever `consumer.receive` raises when the wait expires propagates
        to the caller.
        """

        request_id = str(uuid.uuid4())

        r = GraphEmbeddingsRequest(
            vectors=vectors,
            limit=limit,
        )

        self.producer.send(r, properties={"id": request_id})

        while True:

            msg = self.consumer.receive(timeout_millis=timeout * 1000)

            # .get(): a message without an "id" property is treated like any
            # other non-matching message instead of raising KeyError before
            # it can be acknowledged.
            if msg.properties().get("id") == request_id:
                resp = msg.value().entities
                self.consumer.acknowledge(msg)
                return resp

            # Ignore messages with wrong ID
            self.consumer.acknowledge(msg)

    def __del__(self):
        """Release Pulsar resources, skipping any that were never created."""

        if hasattr(self, "consumer"):
            self.consumer.close()

        if hasattr(self, "producer"):
            self.producer.flush()
            self.producer.close()

        # __init__ may have failed before `client` was assigned; guard it
        # the same way as the consumer and producer.
        if hasattr(self, "client"):
            self.client.close()

View file

@ -1,11 +1,19 @@
from trustgraph.trustgraph import TrustGraph
from trustgraph.triple_vectors import TripleVectors
from trustgraph.trustgraph import TrustGraph
from trustgraph.llm_client import LlmClient
from trustgraph.embeddings_client import EmbeddingsClient
from . schema import text_completion_request_queue
from . schema import text_completion_response_queue
from . graph_embeddings_client import GraphEmbeddingsClient
from . triples_query_client import TriplesQueryClient
from . embeddings_client import EmbeddingsClient
from . prompt_client import PromptClient
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import graph_embeddings_request_queue
from . schema import graph_embeddings_response_queue
from . schema import triples_request_queue
from . schema import triples_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
@ -14,13 +22,15 @@ class GraphRag:
def __init__(
self,
graph_hosts=None,
pulsar_host="pulsar://pulsar:6650",
vector_store="http://milvus:19530",
completion_request_queue=None,
completion_response_queue=None,
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
ge_request_queue=None,
ge_response_queue=None,
tpl_request_queue=None,
tpl_response_queue=None,
verbose=False,
entity_limit=50,
triple_limit=30,
@ -30,25 +40,46 @@ class GraphRag:
self.verbose=verbose
if completion_request_queue == None:
completion_request_queue = text_completion_request_queue
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if completion_response_queue == None:
completion_response_queue = text_completion_response_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue == None:
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue == None:
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if graph_hosts == None:
graph_hosts = ["cassandra"]
if ge_request_queue is None:
ge_request_queue = graph_embeddings_request_queue
if ge_response_queue is None:
ge_response_queue = graph_embeddings_response_queue
if tpl_request_queue is None:
tpl_request_queue = triples_request_queue
if tpl_response_queue is None:
tpl_response_queue = triples_response_queue
if self.verbose:
print("Initialising...", flush=True)
self.graph = TrustGraph(graph_hosts)
self.ge_client = GraphEmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-ge",
input_queue=ge_request_queue,
output_queue=ge_response_queue,
)
self.triples_client = TriplesQueryClient(
pulsar_host=pulsar_host,
subscriber=module + "-tpl",
input_queue=tpl_request_queue,
output_queue=tpl_response_queue
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
@ -57,19 +88,17 @@ class GraphRag:
subscriber=module + "-emb",
)
self.vecstore = TripleVectors(vector_store)
self.entity_limit=entity_limit
self.query_limit=triple_limit
self.max_subgraph_size=max_subgraph_size
self.label_cache = {}
self.llm = LlmClient(
self.lang = PromptClient(
pulsar_host=pulsar_host,
input_queue=completion_request_queue,
output_queue=completion_response_queue,
subscriber=module + "-llm",
input_queue=prompt_request_queue,
output_queue=prompt_response_queue,
subscriber=module + "-prompt",
)
if self.verbose:
@ -89,70 +118,43 @@ class GraphRag:
def get_entities(self, query):
everything = []
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
for vector in vectors:
entities = self.ge_client.request(
vectors, self.entity_limit
)
res = self.vecstore.search(
vector,
limit=self.entity_limit
)
print("Obtained", len(res), "entities")
entities = set([
item["entity"]["entity"]
for item in res
])
everything.extend(entities)
entities = [
e.value
for e in entities
]
if self.verbose:
print("Entities:", flush=True)
for ent in everything:
for ent in entities:
print(" ", ent, flush=True)
return everything
return entities
def maybe_label(self, e):
if e in self.label_cache:
return self.label_cache[e]
res = self.graph.get_sp(e, LABEL)
res = list(res)
res = self.triples_client.request(
e, LABEL, None, limit=1
)
if len(res) == 0:
self.label_cache[e] = e
return e
self.label_cache[e] = res[0][0]
self.label_cache[e] = res[0].o.value
return self.label_cache[e]
def get_nodes(self, query):
ents = self.get_entities(query)
if self.verbose:
print("Get labels...", flush=True)
nodes = [
self.maybe_label(e)
for e in ents
]
if self.verbose:
print("Nodes:", flush=True)
for node in nodes:
print(" ", node, flush=True)
return nodes
def get_subgraph(self, query):
entities = self.get_entities(query)
@ -164,17 +166,35 @@ class GraphRag:
for e in entities:
res = self.graph.get_s(e, limit=self.query_limit)
for p, o in res:
subgraph.add((e, p, o))
res = self.triples_client.request(
e, None, None,
limit=self.query_limit
)
res = self.graph.get_p(e, limit=self.query_limit)
for s, o in res:
subgraph.add((s, e, o))
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.graph.get_o(e, limit=self.query_limit)
for s, p in res:
subgraph.add((s, p, e))
res = self.triples_client.request(
None, e, None,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.triples_client.request(
None, None, e,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
subgraph = list(subgraph)
@ -209,47 +229,19 @@ class GraphRag:
return sg2
def get_cypher(self, query):
sg = self.get_labelgraph(query)
sg2 = []
for s, p, o in sg:
sg2.append(f"({s})-[{p}]->({o})")
kg = "\n".join(sg2)
kg = kg.replace("\\", "-")
return kg
def get_graph_prompt(self, query):
kg = self.get_cypher(query)
prompt=f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
Here's the knowledge statements:
{kg}
Use only the provided knowledge statements to respond to the following:
{query}
"""
return prompt
def query(self, query):
if self.verbose:
print("Construct prompt...", flush=True)
prompt = self.get_graph_prompt(query)
kg = self.get_labelgraph(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(kg)
print(query)
resp = self.llm.request(prompt)
resp = self.lang.request_kg_prompt(query, kg)
if self.verbose:
print("Done", flush=True)

View file

@ -9,11 +9,10 @@ import json
from ... schema import ChunkEmbeddings, Triple, Source, Value
from ... schema import chunk_embeddings_ingest_queue, triples_store_queue
from ... schema import text_completion_request_queue
from ... schema import text_completion_response_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... log_level import LogLevel
from ... llm_client import LlmClient
from ... prompts import to_definitions
from ... prompt_client import PromptClient
from ... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
from ... base import ConsumerProducer
@ -32,11 +31,11 @@ class Processor(ConsumerProducer):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
tc_request_queue = params.get(
"text_completion_request_queue", text_completion_request_queue
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
tc_response_queue = params.get(
"text_completion_response_queue", text_completion_response_queue
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
super(Processor, self).__init__(
@ -46,16 +45,16 @@ class Processor(ConsumerProducer):
"subscriber": subscriber,
"input_schema": ChunkEmbeddings,
"output_schema": Triple,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
}
)
self.llm = LlmClient(
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
input_queue=tc_request_queue,
output_queue=tc_response_queue,
subscriber = module + "-llm",
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",
)
def to_uri(self, text):
@ -68,12 +67,7 @@ class Processor(ConsumerProducer):
def get_definitions(self, chunk):
prompt = to_definitions(chunk)
resp = self.llm.request(prompt)
defs = json.loads(resp)
return defs
return self.prompt.request_definitions(chunk)
def emit_edge(self, s, p, o):
@ -90,14 +84,13 @@ class Processor(ConsumerProducer):
try:
defs = self.get_definitions(chunk)
print(json.dumps(defs, indent=4), flush=True)
for defn in defs:
s = defn["entity"]
s = defn.name
s_uri = self.to_uri(s)
o = defn["definition"]
o = defn.definition
if s == "": continue
if o == "": continue
@ -121,15 +114,15 @@ class Processor(ConsumerProducer):
)
parser.add_argument(
'--text-completion-request-queue',
default=text_completion_request_queue,
help=f'Text completion request queue (default: {text_completion_request_queue})',
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--text-completion-response-queue',
default=text_completion_response_queue,
help=f'Text completion response queue (default: {text_completion_response_queue})',
'--prompt-completion-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
def run():

View file

@ -6,18 +6,16 @@ graph edges.
"""
import urllib.parse
import json
import os
from pulsar.schema import JsonSchema
from ... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
from ... schema import chunk_embeddings_ingest_queue, triples_store_queue
from ... schema import graph_embeddings_store_queue
from ... schema import text_completion_request_queue
from ... schema import text_completion_response_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... log_level import LogLevel
from ... llm_client import LlmClient
from ... prompts import to_relationships
from ... prompt_client import PromptClient
from ... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
from ... base import ConsumerProducer
@ -38,11 +36,11 @@ class Processor(ConsumerProducer):
output_queue = params.get("output_queue", default_output_queue)
vector_queue = params.get("vector_queue", default_vector_queue)
subscriber = params.get("subscriber", default_subscriber)
tc_request_queue = params.get(
"text_completion_request_queue", text_completion_request_queue
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
tc_response_queue = params.get(
"text_completion_response_queue", text_completion_response_queue
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
super(Processor, self).__init__(
@ -52,8 +50,8 @@ class Processor(ConsumerProducer):
"subscriber": subscriber,
"input_schema": ChunkEmbeddings,
"output_schema": Triple,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
}
)
@ -66,19 +64,19 @@ class Processor(ConsumerProducer):
"input_queue": input_queue,
"output_queue": output_queue,
"vector_queue": vector_queue,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings.__name__,
"output_schema": Triple.__name__,
"vector_schema": GraphEmbeddings.__name__,
})
self.llm = LlmClient(
pulsar_host = self.pulsar_host,
input_queue=tc_request_queue,
output_queue=tc_response_queue,
subscriber = module + "-llm",
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",
)
def to_uri(self, text):
@ -91,12 +89,7 @@ class Processor(ConsumerProducer):
def get_relationships(self, chunk):
prompt = to_relationships(chunk)
resp = self.llm.request(prompt)
rels = json.loads(resp)
return rels
return self.prompt.request_relationships(chunk)
def emit_edge(self, s, p, o):
@ -118,13 +111,12 @@ class Processor(ConsumerProducer):
try:
rels = self.get_relationships(chunk)
print(json.dumps(rels, indent=4), flush=True)
for rel in rels:
s = rel["subject"]
p = rel["predicate"]
o = rel["object"]
s = rel.s
p = rel.p
o = rel.o
if s == "": continue
if p == "": continue
@ -136,7 +128,7 @@ class Processor(ConsumerProducer):
p_uri = self.to_uri(p)
p_value = Value(value=str(p_uri), is_uri=True)
if rel["object-entity"]:
if rel.o_entity:
o_uri = self.to_uri(o)
o_value = Value(value=str(o_uri), is_uri=True)
else:
@ -162,7 +154,7 @@ class Processor(ConsumerProducer):
Value(value=str(p), is_uri=False)
)
if rel["object-entity"]:
if rel.o_entity:
# Label for o
self.emit_edge(
o_value,
@ -172,7 +164,7 @@ class Processor(ConsumerProducer):
self.emit_vec(s_value, v.vectors)
self.emit_vec(p_value, v.vectors)
if rel["object-entity"]:
if rel.o_entity:
self.emit_vec(o_value, v.vectors)
except Exception as e:
@ -195,15 +187,15 @@ class Processor(ConsumerProducer):
)
parser.add_argument(
'--text-completion-request-queue',
default=text_completion_request_queue,
help=f'Text completion request queue (default: {text_completion_request_queue})',
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--text-completion-response-queue',
default=text_completion_response_queue,
help=f'Text completion response queue (default: {text_completion_response_queue})',
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
def run():

View file

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
# Script entry point: calls the `run` function imported from the sibling
# `service` module when this file is executed as the main program.
from . service import run
if __name__ == '__main__':
    run()

View file

@ -0,0 +1,81 @@
def to_relationships(text):
    """Return the relationship-extraction prompt with `text` embedded.

    The prompt asks for a JSON array of objects with the fields `subject`,
    `predicate`, `object` and `object-entity`.
    """
    return f"""<instructions>
Study the following text and derive entity relationships. For each
relationship, derive the subject, predicate and object of the relationship.
Output relationships in JSON format as an arary of objects with fields:
- subject: the subject of the relationship
- predicate: the predicate
- object: the object of the relationship
- object-entity: false if the object is a simple data type: name, value or date. true if it is an entity.
</instructions>
<text>
{text}
</text>
<requirements>
You will respond only with raw JSON format data. Do not provide
explanations. Do not use special characters in the abstract text. The
abstract must be written as plain text. Do not add markdown formatting
or headers or prefixes.
</requirements>"""
def to_definitions(text):
    """Return the definition-extraction prompt with `text` embedded.

    The prompt asks for a JSON array of objects with the fields `entity`
    and `definition`.
    """
    return f"""<instructions>
Study the following text and derive definitions for any discovered entities.
Do not provide definitions for entities whose definitions are incomplete
or unknown.
Output relationships in JSON format as an arary of objects with fields:
- entity: the name of the entity
- definition: English text which defines the entity
</instructions>
<text>
{text}
</text>
<requirements>
You will respond only with raw JSON format data. Do not provide
explanations. Do not use special characters in the abstract text. The
abstract will be written as plain text. Do not add markdown formatting
or headers or prefixes. Do not include null or unknown definitions.
</requirements>"""
def get_cypher(kg):
    """Render facts as `(s)-[p]->(o)` statements, one per line.

    Each item of `kg` is read through its `s`, `p` and `o` attributes.
    Every fact, and then the list of rendered statements, is printed.
    Backslashes in the joined text are replaced with `-`.
    """
    statements = []
    for fact in kg:
        print(fact)
        statements.append(f"({fact.s})-[{fact.p}]->({fact.o})")
    print(statements)
    return "\n".join(statements).replace("\\", "-")
def to_kg_query(query, kg):
    """Return the question-answering prompt for `query` over the facts `kg`.

    The facts are rendered with `get_cypher` and embedded ahead of the
    query text.
    """
    statements = get_cypher(kg)
    return f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
Here's the knowledge statements:
{statements}
Use only the provided knowledge statements to respond to the following:
{query}
"""

View file

@ -0,0 +1,195 @@
"""
Language service abstracts prompt engineering from LLM.
"""
import json
from .... schema import Definition, Relationship, Triple
from .... schema import PromptRequest, PromptResponse
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... schema import prompt_request_queue, prompt_response_queue
from .... base import ConsumerProducer
from .... llm_client import LlmClient
from . prompts import to_definitions, to_relationships, to_kg_query
# Service name: the dotted module path with its first and last components
# removed.
module = ".".join(__name__.split(".")[1:-1])
# Fallbacks used by Processor when the matching parameter is not supplied.
default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue
# The service name doubles as the default subscriber name.
default_subscriber = module
class Processor(ConsumerProducer):
    """Prompt service: turns prompt requests into LLM calls.

    Each incoming `PromptRequest` names a `kind`; the matching prompt is
    built, sent to the text-completion service through `LlmClient`, and the
    result is published as a `PromptResponse` carrying the sender's `id`
    property.
    """

    def __init__(self, **params):
        """Resolve queue/subscriber settings and create the LLM client.

        Missing parameters fall back to the module-level defaults and the
        text-completion queue names imported from the schema.
        """

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        tc_request_queue = params.get(
            "text_completion_request_queue", text_completion_request_queue
        )
        tc_response_queue = params.get(
            "text_completion_response_queue", text_completion_response_queue
        )

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": PromptRequest,
                "output_schema": PromptResponse,
                "text_completion_request_queue": tc_request_queue,
                "text_completion_response_queue": tc_response_queue,
            }
        )

        self.llm = LlmClient(
            subscriber=subscriber,
            input_queue=tc_request_queue,
            output_queue=tc_response_queue,
            pulsar_host = self.pulsar_host
        )

    def handle(self, msg):
        """Dispatch one request message to the handler for its `kind`.

        An unrecognised kind is reported on stdout and no response is sent.
        """

        v = msg.value()

        # Sender-produced ID
        request_id = msg.properties()["id"]

        kind = v.kind

        print(f"Handling kind {kind}...", flush=True)

        if kind == "extract-definitions":
            self.handle_extract_definitions(request_id, v)
        elif kind == "extract-relationships":
            self.handle_extract_relationships(request_id, v)
        elif kind == "kg-prompt":
            self.handle_kg_prompt(request_id, v)
        else:
            print("Invalid kind.", flush=True)

    def handle_extract_definitions(self, id, v):
        """Extract entity definitions from `v.chunk` and send them back.

        The LLM answer is parsed with `json.loads`; a parse failure
        propagates. Items that cannot be converted are logged and skipped.
        """

        prompt = to_definitions(v.chunk)

        print(prompt)

        ans = self.llm.request(prompt)

        print(ans)

        defs = json.loads(ans)

        output = []

        for defn in defs:

            try:
                e = defn["entity"]
                d = defn["definition"]

                output.append(
                    Definition(
                        name=e, definition=d
                    )
                )

            # Best-effort per item, but report the problem rather than
            # swallowing it (and never trap KeyboardInterrupt/SystemExit);
            # mirrors handle_extract_relationships.
            except Exception as e:
                print(e)

        print("Send response...", flush=True)
        r = PromptResponse(definitions=output)
        self.producer.send(r, properties={"id": id})

        print("Done.", flush=True)

    def handle_extract_relationships(self, id, v):
        """Extract relationships from `v.chunk` and send them back.

        The LLM answer is parsed with `json.loads`; a parse failure
        propagates. Items that cannot be converted are logged and skipped.
        """

        prompt = to_relationships(v.chunk)

        ans = self.llm.request(prompt)

        rels = json.loads(ans)

        output = []

        for rel in rels:

            try:
                output.append(
                    Relationship(
                        s = rel["subject"],
                        p = rel["predicate"],
                        o = rel["object"],
                        o_entity = rel["object-entity"],
                    )
                )

            except Exception as e:
                print(e)

        print("Send response...", flush=True)
        r = PromptResponse(relationships=output)
        self.producer.send(r, properties={"id": id})

        print("Done.", flush=True)

    def handle_kg_prompt(self, id, v):
        """Answer `v.query` from the facts in `v.kg` and send the answer."""

        prompt = to_kg_query(v.query, v.kg)

        print(prompt)

        ans = self.llm.request(prompt)

        print(ans)

        print("Send response...", flush=True)
        r = PromptResponse(answer=ans)
        self.producer.send(r, properties={"id": id})

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):
        """Register the base arguments plus the text-completion queues."""

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '--text-completion-request-queue',
            default=text_completion_request_queue,
            help=f'Text completion request queue (default: {text_completion_request_queue})',
        )

        parser.add_argument(
            '--text-completion-response-queue',
            default=text_completion_response_queue,
            help=f'Text completion response queue (default: {text_completion_response_queue})',
        )
def run():
    """Call `Processor.start` with the service name and module docstring."""
    Processor.start(module, __doc__)

143
trustgraph/prompt_client.py Normal file
View file

@ -0,0 +1,143 @@
#!/usr/bin/env python3
import pulsar
import _pulsar
from pulsar.schema import JsonSchema
import hashlib
import uuid
from . schema import PromptRequest, PromptResponse, Fact
from . schema import prompt_request_queue
from . schema import prompt_response_queue
# Ugly: the log-level values are read from the underscore-prefixed
# `_pulsar` module and re-exposed here under short names.
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class PromptClient:
    """Request/response client for the prompt service queues.

    Every request is published on `input_queue` tagged with a freshly
    generated `id` message property; the client then reads `output_queue`
    until it sees a response carrying the same `id`.
    """

    def __init__(
            self, log_level=ERROR,
            subscriber=None,
            input_queue=None,
            output_queue=None,
            pulsar_host="pulsar://pulsar:6650",
    ):
        """Connect to Pulsar and create the producer and consumer.

        Args:
            log_level: level handed to `pulsar.ConsoleLogger`.
            subscriber: subscription name; a random UUID when None.
            input_queue: topic requests are sent to; defaults to
                `prompt_request_queue`.
            output_queue: topic responses are read from; defaults to
                `prompt_response_queue`.
            pulsar_host: Pulsar service URL.
        """

        if input_queue is None:
            input_queue = prompt_request_queue

        if output_queue is None:
            output_queue = prompt_response_queue

        if subscriber is None:
            subscriber = str(uuid.uuid4())

        self.client = pulsar.Client(
            pulsar_host,
            logger=pulsar.ConsoleLogger(log_level),
        )

        self.producer = self.client.create_producer(
            topic=input_queue,
            schema=JsonSchema(PromptRequest),
            chunking_enabled=True,
        )

        self.consumer = self.client.subscribe(
            output_queue, subscriber,
            schema=JsonSchema(PromptResponse),
        )

    def _request(self, request, field, timeout):
        """Send `request` and return attribute `field` of the matching reply.

        Every received message is acknowledged; those whose `id` property
        does not match this request are discarded. Whatever
        `consumer.receive` raises when the wait expires propagates.
        """

        request_id = str(uuid.uuid4())

        self.producer.send(request, properties={"id": request_id})

        while True:

            msg = self.consumer.receive(timeout_millis=timeout * 1000)

            # .get(): a message without an "id" property is treated like any
            # other non-matching message instead of raising KeyError before
            # it can be acknowledged.
            if msg.properties().get("id") == request_id:
                resp = getattr(msg.value(), field)
                self.consumer.acknowledge(msg)
                return resp

            # Ignore messages with wrong ID
            self.consumer.acknowledge(msg)

    def request_definitions(self, chunk, timeout=500):
        """Request definition extraction for `chunk`.

        Returns the `definitions` field of the response. `timeout` is the
        per-message wait in seconds.
        """

        r = PromptRequest(
            kind="extract-definitions",
            chunk=chunk,
        )

        return self._request(r, "definitions", timeout)

    def request_relationships(self, chunk, timeout=500):
        """Request relationship extraction for `chunk`.

        Returns the `relationships` field of the response. `timeout` is the
        per-message wait in seconds.
        """

        r = PromptRequest(
            kind="extract-relationships",
            chunk=chunk,
        )

        return self._request(r, "relationships", timeout)

    def request_kg_prompt(self, query, kg, timeout=500):
        """Ask `query` against the facts in `kg`.

        Each item of `kg` is indexed at positions 0, 1 and 2 to build a
        `Fact(s, p, o)`. Returns the `answer` field of the response.
        `timeout` is the per-message wait in seconds.
        """

        r = PromptRequest(
            kind="kg-prompt",
            query=query,
            kg=[
                Fact(s=v[0], p=v[1], o=v[2])
                for v in kg
            ],
        )

        return self._request(r, "answer", timeout)

    def __del__(self):
        """Release Pulsar resources, skipping any that were never created."""

        if hasattr(self, "consumer"):
            self.consumer.close()

        if hasattr(self, "producer"):
            self.producer.flush()
            self.producer.close()

        # __init__ may have failed before `client` was assigned; guard it
        # the same way as the consumer and producer.
        if hasattr(self, "client"):
            self.client.close()

View file

@ -1,138 +0,0 @@
def turtle_extract(text):
    """Return the prompt asking for `text` to be expressed as Turtle RDF."""
    return f"""<instructions>
Study the following text and extract knowledge as
information in Turtle RDF format.
When declaring any new URIs, use <https://trustgraph.ai/e#> prefix,
and declare appropriate namespace tags.
</instructions>
<text>
{text}
</text>
<requirements>
Do not use placeholders for information you do not know.
You will respond only with raw Turtle RDF data. Do not provide
explanations. Do not use special characters in the abstract text. The
abstract must be written as plain text. Do not add markdown formatting.
</requirements>"""
def scholar(text):
    """Return the scholarly-abstract prompt with `text` embedded.

    The prompt includes an example JSON object with the fields `title`,
    `abstract`, `keywords` and `people`.
    """
    # Example of the JSON object the prompt shows for Article style extraction
    example_json = """{
"title": "Article title here",
"abstract": "Abstract text here",
"keywords": ["keyword1", "keyword2", "keyword3"],
"people": ["person1", "person2", "person3"]
}"""
    return f"""Your task is to read the provided text and write a scholarly abstract to fully explain all of the concepts described in the provided text. The abstract must include all conceptual details.
<text>
{text}
</text>
<instructions>
- Structure: For the provided text, write a title, abstract, keywords,
and people for the concepts found in the provided text. Ignore
document formatting in the provided text such as table of contents,
headers, footers, section metadata, and URLs.
- Focus on Concepts The abstract must focus on concepts found in the
provided text. The abstract must be factually accurate. Do not
write any concepts not found in the provided text. Do not
speculate. Do not omit any conceptual details.
- Completeness: The abstract must capture all topics the reader will
need to understand the concepts found in the provided text. Describe
all terms, definitions, entities, people, events, concepts,
conceptual relationships, and any other topics necessary for the
reader to understand the concepts of the provided text.
- Format: Respond in the form of a valid JSON object.
</instructions>
<example>
{example_json}
</example>
<requirements>
You will respond only with the JSON object. Do not provide
explanations. Do not use special characters in the abstract text. The
abstract must be written as plain text.
</requirements>"""
def to_json_ld(text):
    """Return the prompt asking for facts in `text` as JSON-LD."""
    return f"""<instructions>
Study the following text and output any facts you discover in
well-structured JSON-LD format.
Use any schema you understand from schema.org to describe the facts.
</instructions>
<text>
{text}
</text>
<requirements>
You will respond only with raw JSON-LD data in JSON format. Do not provide
explanations. Do not use special characters in the abstract text. The
abstract must be written as plain text. Do not add markdown formatting
or headers or prefixes. Do not use information which is not present in
the input text.
</requirements>"""
def to_relationships(text):
    """Return the relationship-extraction prompt with `text` embedded.

    The prompt asks for a JSON array of objects with the fields `subject`,
    `predicate`, `object` and `object-entity`.
    """
    return f"""<instructions>
Study the following text and derive entity relationships. For each
relationship, derive the subject, predicate and object of the relationship.
Output relationships in JSON format as an arary of objects with fields:
- subject: the subject of the relationship
- predicate: the predicate
- object: the object of the relationship
- object-entity: false if the object is a simple data type: name, value or date. true if it is an entity.
</instructions>
<text>
{text}
</text>
<requirements>
You will respond only with raw JSON format data. Do not provide
explanations. Do not use special characters in the abstract text. The
abstract must be written as plain text. Do not add markdown formatting
or headers or prefixes.
</requirements>"""
def to_definitions(text):
    """Return the definition-extraction prompt with `text` embedded.

    The prompt asks for a JSON array of objects with the fields `entity`
    and `definition`.
    """
    return f"""<instructions>
Study the following text and derive definitions for any discovered entities.
Do not provide definitions for entities whose definitions are incomplete
or unknown.
Output relationships in JSON format as an arary of objects with fields:
- entity: the name of the entity
- definition: English text which defines the entity
</instructions>
<text>
{text}
</text>
<requirements>
You will respond only with raw JSON format data. Do not provide
explanations. Do not use special characters in the abstract text. The
abstract will be written as plain text. Do not add markdown formatting
or headers or prefixes. Do not include null or unknown definitions.
</requirements>"""

View file

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
# Script entry point. `run` is imported from the sibling `service` module;
# the previous `from . hf import run` pointed at a module name that does not
# appear to exist in this package (NOTE(review): confirm no `hf.py` here).
from . service import run

if __name__ == '__main__':
    run()

View file

@ -0,0 +1,100 @@
"""
Graph embeddings query service. Input is vector, output is list of
entities
"""
from .... direct.milvus import TripleVectors
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse, Value
from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
# Service name: the dotted module path with its first and last components
# removed.
module = ".".join(__name__.split(".")[1:-1])
# Fallbacks used by Processor when the matching parameter is not supplied.
default_input_queue = graph_embeddings_request_queue
default_output_queue = graph_embeddings_response_queue
# The service name doubles as the default subscriber name.
default_subscriber = module
# Default URI handed to TripleVectors (described as the Milvus store URI in
# the --store-uri help text).
default_store_uri = 'http://localhost:19530'
class Processor(ConsumerProducer):
    """Graph embeddings query processor.

    For each request, searches the `TripleVectors` store once per input
    vector and replies with the distinct entities found, wrapped as
    `Value` objects.
    """

    def __init__(self, **params):
        """Resolve settings (with module defaults) and open the store."""

        store_uri = params.get("store_uri", default_store_uri)

        settings = {
            "input_queue": params.get("input_queue", default_input_queue),
            "output_queue": params.get("output_queue", default_output_queue),
            "subscriber": params.get("subscriber", default_subscriber),
            "input_schema": GraphEmbeddingsRequest,
            "output_schema": GraphEmbeddingsResponse,
            "store_uri": store_uri,
        }

        super(Processor, self).__init__(**(params | settings))

        self.vecstore = TripleVectors(store_uri)

    def create_value(self, ent):
        """Wrap `ent` in a `Value`, flagged as a URI for http(s) prefixes."""
        looks_like_uri = ent.startswith("http://") or ent.startswith("https://")
        return Value(value=ent, is_uri=looks_like_uri)

    def handle(self, msg):
        """Serve one request and send the response tagged with its `id`."""

        request = msg.value()

        # Sender-produced ID
        request_id = msg.properties()["id"]

        print(f"Handling input {request_id}...", flush=True)

        # A set, so an entity matched by several vectors appears once
        found = {
            hit["entity"]["entity"]
            for vec in request.vectors
            for hit in self.vecstore.search(vec, limit=request.limit)
        }

        values = [self.create_value(ent) for ent in found]

        print("Send response...", flush=True)
        response = GraphEmbeddingsResponse(entities=values)
        self.producer.send(response, properties={"id": request_id})

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):
        """Register the base arguments plus the store URI option."""

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-t', '--store-uri',
            default=default_store_uri,
            help=f'Milvus store URI (default: {default_store_uri})'
        )
# Entry point: hand control to Processor.start, giving it this service's
# module name and description. `__doc__` is the module-level docstring here
# (a bare name inside a function resolves to the module global).
def run():
Processor.start(module, __doc__)

View file

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
"""Command-line entry point: run this package's query service."""

# ``run`` is defined in the sibling ``service`` module (the package
# ``__init__`` imports from it as well); the previous ``hf`` import looks
# like a copy-paste leftover and fails under ``python -m``.
from . service import run

if __name__ == '__main__':
    run()

View file

@ -0,0 +1,153 @@
"""
Triples query service. Input is a (s, p, o) triple, some values may be
null. Output is a list of triples.
"""
from .... direct.cassandra import TrustGraph
from .... schema import TriplesQueryRequest, TriplesQueryResponse
from .... schema import Value, Triple
from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
# Service name: this module's dotted import path with its first and last
# components dropped.
module = ".".join(__name__.split(".")[1:-1])
# Defaults applied when neither the constructor params nor the CLI
# override them.
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue
default_subscriber = module
default_graph_host='localhost'
class Processor(ConsumerProducer):
    """Triples query processor backed by the ``TrustGraph`` store.

    Consumes ``TriplesQueryRequest`` messages, where any of ``s``, ``p``
    and ``o`` may be null, and replies with a ``TriplesQueryResponse``
    listing the stored triples that match the positions that were given.
    """

    def __init__(self, **params):
        """Resolve settings from *params* and connect to the graph store.

        Recognised keys, each falling back to the module-level default:
        ``input_queue``, ``output_queue``, ``subscriber``, ``graph_host``.
        Everything in *params* is forwarded to ``ConsumerProducer``.
        """

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        graph_host = params.get("graph_host", default_graph_host)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": TriplesQueryRequest,
                "output_schema": TriplesQueryResponse,
                "graph_host": graph_host,
            }
        )

        self.tg = TrustGraph([graph_host])

    def create_value(self, ent):
        """Wrap the string *ent* in a ``Value``; http(s) strings are URIs."""
        return Value(
            value=ent,
            is_uri=ent.startswith(("http://", "https://")),
        )

    def _query(self, s, p, o, limit):
        """Return the matching triples as ``(s, p, o)`` string tuples.

        Each of *s*, *p*, *o* is either the string that position must
        equal or ``None`` to leave it unconstrained; the combination picks
        which ``TrustGraph`` lookup is used. *limit* is passed through to
        that lookup.
        """

        tg = self.tg

        if s is not None and p is not None and o is not None:
            # Fully specified: the lookup can only confirm existence, so
            # report the triple only if the store actually returned a row.
            rows = tg.get_spo(s, p, o, limit=limit)
            found = any(True for _ in rows)
            return [(s, p, o)] if found else []

        if s is not None:
            if p is not None:
                return [(s, p, t.o) for t in tg.get_sp(s, p, limit=limit)]
            if o is not None:
                return [(s, t.p, o) for t in tg.get_os(o, s, limit=limit)]
            return [(s, t.p, t.o) for t in tg.get_s(s, limit=limit)]

        if p is not None:
            if o is not None:
                return [(t.s, p, o) for t in tg.get_po(p, o, limit=limit)]
            return [(t.s, p, t.o) for t in tg.get_p(p, limit=limit)]

        if o is not None:
            return [(t.s, t.p, o) for t in tg.get_o(o, limit=limit)]

        return [(t.s, t.p, t.o) for t in tg.get_all(limit=limit)]

    def handle(self, msg):
        """Answer one request message.

        The reply is sent on ``self.producer`` with the request's ``id``
        property copied over so the sender can correlate it.
        """

        v = msg.value()

        # Sender-produced ID
        request_id = msg.properties()["id"]

        print(f"Handling input {request_id}...", flush=True)

        # A null term means "any"; otherwise match on the wrapped string.
        s, p, o = (
            None if term is None else term.value
            for term in (v.s, v.p, v.o)
        )

        triples = [
            Triple(
                s=self.create_value(ts),
                p=self.create_value(tp),
                o=self.create_value(to),
            )
            for ts, tp, to in self._query(s, p, o, v.limit)
        ]

        print("Send response...", flush=True)
        r = TriplesQueryResponse(triples=triples)
        self.producer.send(r, properties={"id": request_id})

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):
        """Register the base queue options plus ``-g/--graph-host``."""

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        # Use the shared constant so the CLI default and the constructor
        # default cannot drift apart.
        parser.add_argument(
            '-g', '--graph-host',
            default=default_graph_host,
            help=f'Graph host (default: {default_graph_host})'
        )
# Entry point: hand control to Processor.start, giving it this service's
# module name and description. `__doc__` is the module-level docstring here
# (a bare name inside a function resolves to the module global).
def run():
Processor.start(module, __doc__)

View file

@ -6,10 +6,14 @@ Input is query, output is response.
from ... schema import GraphRagQuery, GraphRagResponse
from ... schema import graph_rag_request_queue, graph_rag_response_queue
from ... schema import text_completion_request_queue
from ... schema import text_completion_response_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue
from ... schema import graph_embeddings_request_queue
from ... schema import graph_embeddings_response_queue
from ... schema import triples_request_queue
from ... schema import triples_response_queue
from ... log_level import LogLevel
from ... graph_rag import GraphRag
from ... base import ConsumerProducer
@ -19,8 +23,6 @@ module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_rag_request_queue
default_output_queue = graph_rag_response_queue
default_subscriber = module
default_graph_hosts = 'localhost'
default_vector_store = 'http://localhost:19530'
class Processor(ConsumerProducer):
@ -29,16 +31,14 @@ class Processor(ConsumerProducer):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
graph_hosts = params.get("graph_hosts", default_graph_hosts)
vector_store = params.get("vector_store", default_vector_store)
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 3000)
tc_request_queue = params.get(
"text_completion_request_queue", text_completion_request_queue
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
tc_response_queue = params.get(
"text_completion_response_queue", text_completion_response_queue
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
@ -46,6 +46,18 @@ class Processor(ConsumerProducer):
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
ge_request_queue = params.get(
"graph_embeddings_request_queue", graph_embeddings_request_queue
)
ge_response_queue = params.get(
"graph_embeddings_response_queue", graph_embeddings_response_queue
)
tpl_request_queue = params.get(
"triples_request_queue", triples_request_queue
)
tpl_response_queue = params.get(
"triples_response_queue", triples_response_queue
)
super(Processor, self).__init__(
**params | {
@ -57,21 +69,27 @@ class Processor(ConsumerProducer):
"entity_limit": entity_limit,
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"graph_embeddings_request_queue": ge_request_queue,
"graph_embeddings_response_queue": ge_response_queue,
"triples_request_queue": triples_request_queue,
"triples_response_queue": triples_response_queue,
}
)
self.rag = GraphRag(
pulsar_host=self.pulsar_host,
graph_hosts=graph_hosts.split(","),
completion_request_queue=tc_request_queue,
completion_response_queue=tc_response_queue,
pr_request_queue=pr_request_queue,
pr_response_queue=pr_response_queue,
emb_request_queue=emb_request_queue,
emb_response_queue=emb_response_queue,
vector_store=vector_store,
ge_request_queue=ge_request_queue,
ge_response_queue=ge_response_queue,
# Pass the queues resolved from params; using the schema defaults here
# silently ignored any triples queue override supplied by the caller.
tpl_request_queue=tpl_request_queue,
tpl_response_queue=tpl_response_queue,
verbose=True,
entity_limit=entity_limit,
triple_limit=triple_limit,
@ -139,15 +157,15 @@ class Processor(ConsumerProducer):
)
parser.add_argument(
'--text-completion-request-queue',
default=text_completion_request_queue,
help=f'Text completion request queue (default: {text_completion_request_queue})',
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--text-completion-response-queue',
default=text_completion_response_queue,
help=f'Text completion response queue (default: {text_completion_response_queue})',
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
parser.add_argument(
@ -159,7 +177,31 @@ class Processor(ConsumerProducer):
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings request queue (default: {embeddings_response_queue})',
help=f'Embeddings response queue (default: {embeddings_response_queue})',
)
parser.add_argument(
'--graph-embeddings-request-queue',
default=graph_embeddings_request_queue,
help=f'Graph embeddings request queue (default: {graph_embeddings_request_queue})',
)
parser.add_argument(
    # Hyphenated spelling first, matching the sibling options; the earlier
    # mixed '_'/'-' spelling stays as an alias so existing invocations
    # keep working. The derived dest is unchanged.
    '--graph-embeddings-response-queue',
    '--graph_embeddings-response-queue',
    default=graph_embeddings_response_queue,
    help=f'Graph embeddings response queue (default: {graph_embeddings_response_queue})',
)
parser.add_argument(
'--triples-request-queue',
default=triples_request_queue,
help=f'Triples request queue (default: {triples_request_queue})',
)
parser.add_argument(
'--triples-response-queue',
default=triples_response_queue,
help=f'Triples response queue (default: {triples_response_queue})',
)
def run():

View file

@ -71,6 +71,24 @@ graph_embeddings_store_queue = topic('graph-embeddings-store')
############################################################################
# Graph embeddings query
# Request to the graph-embeddings query service.
class GraphEmbeddingsRequest(Record):
# Query vectors; the query service added in this change searches its store
# once per vector.
vectors = Array(Array(Double()))
# Passed by that service as the `limit` of each per-vector search.
limit = Integer()
# Reply from the graph-embeddings query service.
class GraphEmbeddingsResponse(Record):
# Matching entities, each wrapped as a Value.
entities = Array(Value())
# Request and response topics for graph-embeddings queries; both are
# declared non-persistent, in the 'request' and 'response' namespaces.
graph_embeddings_request_queue = topic(
'graph-embeddings', kind='non-persistent', namespace='request'
)
graph_embeddings_response_queue = topic(
'graph-embeddings-response', kind='non-persistent', namespace='response',
)
############################################################################
# Graph triples
class Triple(Record):
@ -83,6 +101,26 @@ triples_store_queue = topic('triples-store')
############################################################################
# Triples query
# Request to the triples query service. s, p and o may each be left null;
# the query service added in this change treats a null position as
# unconstrained and matches on the positions that are set.
class TriplesQueryRequest(Record):
s = Value()
p = Value()
o = Value()
# Forwarded by that service as the store lookup's `limit`.
limit = Integer()
# Reply from the triples query service.
class TriplesQueryResponse(Record):
# The triples that matched the request.
triples = Array(Triple())
# Request and response topics for triples queries; both are declared
# non-persistent, in the 'request' and 'response' namespaces.
triples_request_queue = topic(
'triples', kind='non-persistent', namespace='request'
)
triples_response_queue = topic(
'triples-response', kind='non-persistent', namespace='response',
)
############################################################################
# chunk_embeddings_store_queue = topic('chunk-embeddings-store')
############################################################################
@ -138,3 +176,47 @@ graph_rag_response_queue = topic(
############################################################################
# Prompt services, abstract the prompt generation
# One extracted definition; element type of PromptResponse.definitions.
class Definition(Record):
# presumably the name of the entity being defined -- confirm against the
# prompt service
name = String()
definition = String()
# One extracted relationship; element type of PromptResponse.relationships.
# s / p / o are plain strings (presumably subject, predicate, object).
class Relationship(Record):
s = String()
p = String()
o = String()
# NOTE(review): looks like this flags `o` as an entity rather than a
# literal value -- confirm against the prompt service
o_entity = Boolean()
# A plain-string s / p / o triple; element type of PromptRequest.kg.
class Fact(Record):
s = String()
p = String()
o = String()
# extract-definitions:
# chunk -> definitions
# extract-relationships:
# chunk -> relationships
# prompt-rag:
# query, triples -> answer
# Request to the prompt service. `kind` names the operation; the kinds noted
# in the comment preceding this class are extract-definitions and
# extract-relationships (input: chunk) and prompt-rag (input: query, triples).
class PromptRequest(Record):
kind = String()
# Text chunk, for the two extraction kinds.
chunk = String()
# Query text, for prompt-rag.
query = String()
# presumably the "triples" input of prompt-rag, as Fact records -- confirm
kg = Array(Fact())
# Reply from the prompt service. Going by the request kinds noted next to
# PromptRequest: prompt-rag yields `answer`, extract-definitions yields
# `definitions`, extract-relationships yields `relationships`.
class PromptResponse(Record):
answer = String()
definitions = Array(Definition())
relationships = Array(Relationship())
# Request and response topics for the prompt service; both are declared
# non-persistent, in the 'request' and 'response' namespaces.
prompt_request_queue = topic(
'prompt', kind='non-persistent', namespace='request'
)
prompt_response_queue = topic(
'prompt-response', kind='non-persistent', namespace='response'
)
############################################################################

View file

@ -6,7 +6,7 @@ Accepts entity/vector pairs and writes them to a Milvus store.
from .... schema import GraphEmbeddings
from .... schema import graph_embeddings_store_queue
from .... log_level import LogLevel
from .... triple_vectors import TripleVectors
from .... direct.milvus import TripleVectors
from .... base import Consumer
module = ".".join(__name__.split(".")[1:-1])
@ -51,8 +51,8 @@ class Processor(Consumer):
parser.add_argument(
'-t', '--store-uri',
default="http://milvus:19530",
help=f'Milvus store URI (default: http://milvus:19530)'
default=default_store_uri,
help=f'Milvus store URI (default: {default_store_uri})'
)
def run():

View file

@ -9,7 +9,7 @@ import os
import argparse
import time
from .... trustgraph import TrustGraph
from .... direct.cassandra import TrustGraph
from .... schema import Triple
from .... schema import triples_store_queue
from .... log_level import LogLevel
@ -34,6 +34,7 @@ class Processor(Consumer):
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Triple,
"graph_host": graph_host,
}
)

View file

@ -0,0 +1,100 @@
#!/usr/bin/env python3
import pulsar
import _pulsar
from pulsar.schema import JsonSchema
import hashlib
import uuid
from . schema import TriplesQueryRequest, TriplesQueryResponse, Value
from . schema import triples_request_queue
from . schema import triples_response_queue
# Short aliases for the _pulsar.LoggerLevel values. Ugly: this reaches into
# the underscore-prefixed _pulsar module. ERROR is the default log level of
# the client class defined below.
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class TriplesQueryClient:
    """Blocking request/response client for the triples query service.

    Sends ``TriplesQueryRequest`` messages on the request queue and waits
    on the response queue for the ``TriplesQueryResponse`` carrying the
    same ``id`` message property.
    """

    def __init__(
            self, log_level=ERROR,
            subscriber=None,
            input_queue=None,
            output_queue=None,
            pulsar_host="pulsar://pulsar:6650",
    ):
        """Connect to Pulsar and open the producer and consumer.

        Args:
            log_level: Level handed to ``pulsar.ConsoleLogger``.
            subscriber: Subscription name; a random UUID when ``None``.
            input_queue: Topic requests are sent on; defaults to
                ``triples_request_queue``.
            output_queue: Topic replies are read from; defaults to
                ``triples_response_queue``.
            pulsar_host: Pulsar service URL.
        """

        if input_queue is None:
            input_queue = triples_request_queue

        if output_queue is None:
            output_queue = triples_response_queue

        if subscriber is None:
            subscriber = str(uuid.uuid4())

        self.client = pulsar.Client(
            pulsar_host,
            logger=pulsar.ConsoleLogger(log_level),
        )

        self.producer = self.client.create_producer(
            topic=input_queue,
            schema=JsonSchema(TriplesQueryRequest),
            chunking_enabled=True,
        )

        self.consumer = self.client.subscribe(
            output_queue, subscriber,
            schema=JsonSchema(TriplesQueryResponse),
        )

    def create_value(self, ent):
        """Wrap *ent* in a ``Value`` (http(s) strings are URIs).

        ``None`` is passed through unchanged so that a query position can
        be left unset.
        """

        if ent is None:
            return None

        return Value(
            value=ent,
            is_uri=ent.startswith(("http://", "https://")),
        )

    def request(self, s, p, o, limit=10, timeout=500):
        """Query for triples and return the ``triples`` list of the reply.

        Args:
            s, p, o: Strings to match in each position, or ``None`` to
                leave that position unset in the request.
            limit: Value of the request's ``limit`` field.
            timeout: Seconds to wait in each ``receive`` call; it is
                applied per received message, not to the call as a whole.
        """

        request_id = str(uuid.uuid4())

        r = TriplesQueryRequest(
            s=self.create_value(s),
            p=self.create_value(p),
            o=self.create_value(o),
            limit=limit,
        )
        self.producer.send(r, properties={"id": request_id})

        while True:

            msg = self.consumer.receive(timeout_millis=timeout * 1000)

            # .get(): a message with no id at all is treated like one with
            # the wrong id instead of aborting the wait with a KeyError.
            if msg.properties().get("id") == request_id:
                triples = msg.value().triples
                self.consumer.acknowledge(msg)
                return triples

            # Ignore messages with wrong ID
            self.consumer.acknowledge(msg)

    def __del__(self):
        """Close whichever Pulsar handles were successfully created."""

        # Every attribute is guarded: __del__ also runs when __init__
        # raised part-way, in which case later attributes do not exist.
        if hasattr(self, "consumer"):
            self.consumer.close()

        if hasattr(self, "producer"):
            self.producer.flush()
            self.producer.close()

        if hasattr(self, "client"):
            self.client.close()