Prompt templates (#33)

* Added prompt-template, allows definiton, relationships and kg query
to be specified in config / command-line.

* Bump version & add prompt-templates to YAMLs

* Apply to graph rag flow

* Break out different templates
This commit is contained in:
cybermaggedon 2024-08-23 23:34:16 +01:00 committed by GitHub
parent e1ecf9f356
commit 6edc3f0ee1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 1268 additions and 298 deletions

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . service import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,19 @@
def to_relationships(template, text):
return template.format(text=text)
def to_definitions(template, text):
return template.format(text=text)
def get_cypher(kg):
sg2 = []
for f in kg:
sg2.append(f"({f.s})-[{f.p}]->({f.o})")
kg = "\n".join(sg2)
kg = kg.replace("\\", "-")
return kg
def to_kg_query(template, query, kg):
cypher = get_cypher(kg)
return template.format(query=query, graph=cypher)

View file

@ -0,0 +1,283 @@
"""
Language service abstracts prompt engineering from LLM.
"""
import json
from .... schema import Definition, Relationship, Triple
from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... schema import prompt_request_queue, prompt_response_queue
from .... base import ConsumerProducer
from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships, to_kg_query
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
tc_request_queue = params.get(
"text_completion_request_queue", text_completion_request_queue
)
tc_response_queue = params.get(
"text_completion_response_queue", text_completion_response_queue
)
definition_template = params.get("definition_template")
relationship_template = params.get("relationship_template")
knowledge_query_template = params.get("knowledge_query_template")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": PromptRequest,
"output_schema": PromptResponse,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
"definition_template": definition_template,
"relationship_template": relationship_template,
"knowledge_query_template": knowledge_query_template,
}
)
self.llm = LlmClient(
subscriber=subscriber,
input_queue=tc_request_queue,
output_queue=tc_response_queue,
pulsar_host = self.pulsar_host
)
self.definition_template = definition_template
self.relationship_template = relationship_template
self.knowledge_query_template = knowledge_query_template
def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
kind = v.kind
print(f"Handling kind {kind}...", flush=True)
if kind == "extract-definitions":
self.handle_extract_definitions(id, v)
return
elif kind == "extract-relationships":
self.handle_extract_relationships(id, v)
return
elif kind == "kg-prompt":
self.handle_kg_prompt(id, v)
return
else:
print("Invalid kind.", flush=True)
return
def handle_extract_definitions(self, id, v):
try:
prompt = to_definitions(self.definition_template, v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try:
defs = json.loads(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = []
for defn in defs:
try:
e = defn["entity"]
d = defn["definition"]
output.append(
Definition(
name=e, definition=d
)
)
except:
print("definition fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(definitions=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_extract_relationships(self, id, v):
try:
prompt = to_relationships(self.relationship_template, v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try:
defs = json.loads(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = []
for defn in defs:
try:
output.append(
Relationship(
s = defn["subject"],
p = defn["predicate"],
o = defn["object"],
o_entity = defn["object-entity"],
)
)
except Exception as e:
print("relationship fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(relationships=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_kg_prompt(self, id, v):
try:
prompt = to_kg_query(self.knowledge_query_template, v.query, v.kg)
print(prompt)
ans = self.llm.request(prompt)
print(ans)
print("Send response...", flush=True)
r = PromptResponse(answer=ans, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--text-completion-request-queue',
default=text_completion_request_queue,
help=f'Text completion request queue (default: {text_completion_request_queue})',
)
parser.add_argument(
'--text-completion-response-queue',
default=text_completion_response_queue,
help=f'Text completion response queue (default: {text_completion_response_queue})',
)
parser.add_argument(
'--definition-template',
required=True,
help=f'Definition extraction template',
)
parser.add_argument(
'--relationship-template',
required=True,
help=f'Relationship extraction template',
)
parser.add_argument(
'--knowledge-query-template',
required=True,
help=f'Knowledge query template',
)
def run():
Processor.start(module, __doc__)

View file

@ -0,0 +1,3 @@
from . write import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . write import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,61 @@
"""
Accepts entity/vector pairs and writes them to a Milvus store.
"""
from .... schema import GraphEmbeddings
from .... schema import graph_embeddings_store_queue
from .... log_level import LogLevel
from .... direct.milvus import TripleVectors
from .... base import Consumer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_store_queue
default_subscriber = module
default_store_uri = 'http://localhost:19530'
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddings,
"store_uri": store_uri,
}
)
self.vecstore = TripleVectors(store_uri)
def handle(self, msg):
v = msg.value()
if v.entity.value != "":
for vec in v.vectors:
self.vecstore.insert(vec, v.entity.value)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Milvus store URI (default: {default_store_uri})'
)
def run():
Processor.start(module, __doc__)