Prompt refactor (#125)

* Prompt manager integrated and working with 6 tests
* Updated templates for the prompt-template update
This commit is contained in:
cybermaggedon 2024-10-26 22:17:43 +01:00 committed by GitHub
parent 51aef6c730
commit 1e137768ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 649 additions and 479 deletions

View file

@ -0,0 +1,25 @@
# Example invocation of the prompt-template service.
#
# Terms ({{name}}-style placeholders) are resolved with this precedence:
# per-call input overrides per-prompt terms (--prompt-term), which override
# globals (--global-term).  Each prompt may declare a response type
# (text | json) and, for json, an optional JSON schema used for validation.
# NOTE: the whole command is one backslash-continued line; do not insert
# comments between the continuation lines.
prompt-template \
-p pulsar://localhost:6650 \
--system-prompt 'You are a {{attitude}}, you are called {{name}}' \
--global-term \
'name=Craig' \
'attitude=LOUD, SHOUTY ANNOYING BOT' \
--prompt \
'question={{question}}' \
'french-question={{question}}' \
"analyze=Find the name and age in this text, and output a JSON structure containing just the name and age fields: {{description}}. Don't add markup, just output the raw JSON object." \
"graph-query=Study the following knowledge graph, and then answer the question.\\n\nGraph:\\n{% for edge in knowledge %}({{edge.0}})-[{{edge.1}}]->({{edge.2}})\\n{%endfor%}\\nQuestion:\\n{{question}}" \
"extract-definition=Analyse the text provided, and then return a list of terms and definitions. The output should be a JSON array, each item in the array is an object with fields 'term' and 'definition'.Don't add markup, just output the raw JSON object. Here is the text:\\n{{text}}" \
--prompt-response-type \
'question=text' \
'analyze=json' \
'graph-query=text' \
'extract-definition=json' \
--prompt-term \
'question=name:Bonny' \
'french-question=attitude:French-speaking bot' \
--prompt-schema \
'analyze={ "type" : "object", "properties" : { "age": { "type" : "number" }, "name": { "type" : "string" } } }' \
'extract-definition={ "type": "array", "items": { "type": "object", "properties": { "term": { "type": "string" }, "definition": { "type": "string" } }, "required": [ "term", "definition" ] } }'

View file

@ -0,0 +1,95 @@
import ibis
import json
from jsonschema import validate
import re
from trustgraph.clients.llm_client import LlmClient
class PromptConfiguration:
    """Bundle of prompt settings: system template, global terms and prompts.

    Parameters:
        system_template: template text for the system prompt.
        global_terms: dict of term bindings shared by all prompts.
        prompts: dict mapping prompt id -> Prompt.
    """

    def __init__(self, system_template, global_terms=None, prompts=None):
        self.system_template = system_template
        # Fix: the originals were mutable default arguments ({}), which are
        # shared across every instance created with the default — mutating
        # one configuration's terms would silently leak into the others.
        self.global_terms = {} if global_terms is None else global_terms
        self.prompts = {} if prompts is None else prompts
class Prompt:
    """One prompt definition: template text plus response-handling metadata.

    Attributes:
        template: template text for this prompt.
        response_type: "text" (raw LLM output) or "json" (parsed output).
        terms: per-prompt term bindings, or None.
        schema: optional JSON schema for validating "json" responses.
    """

    def __init__(self, template, response_type="text", terms=None, schema=None):
        self.template, self.response_type = template, response_type
        self.terms, self.schema = terms, schema
class PromptManager:
    """Renders configured templates, invokes the LLM, and parses responses.

    Fixes over the original: the four bare ``except:`` clauses are narrowed
    to ``except Exception`` (a bare except also swallows KeyboardInterrupt
    and SystemExit) and re-raise with ``from e`` so the underlying cause is
    preserved in the traceback.
    """

    def __init__(self, llm, config):
        """
        llm    -- object exposing request(system=..., prompt=...) -> str
        config -- PromptConfiguration with system template, terms and prompts
        Raises RuntimeError if any template fails to compile.
        """
        self.llm = llm
        self.config = config
        self.terms = config.global_terms
        self.prompts = config.prompts

        try:
            self.system_template = ibis.Template(config.system_template)
        except Exception as e:
            raise RuntimeError("Error in system template") from e

        self.templates = {}
        for k, v in self.prompts.items():
            try:
                self.templates[k] = ibis.Template(v.template)
            except Exception as e:
                raise RuntimeError(f"Error in template: {k}") from e
            # Normalise missing per-prompt terms to {} so invoke() can
            # merge dicts unconditionally.
            if v.terms is None:
                v.terms = {}

    def parse_json(self, text):
        """Parse JSON from LLM output, tolerating ``` / ```json fences."""
        json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
        if json_match:
            json_str = json_match.group(1).strip()
        else:
            # If no delimiters, assume the entire output is JSON
            json_str = text.strip()
        return json.loads(json_str)

    def invoke(self, id, input):
        """Render prompt `id` with merged terms, call the LLM, and return
        either the raw text or the parsed (and optionally schema-validated)
        JSON object, depending on the prompt's response type.

        Raises RuntimeError on unknown id, unknown response type, JSON
        parse failure, or schema validation failure.
        """
        if id not in self.prompts:
            raise RuntimeError("ID invalid")

        # Precedence: call input overrides per-prompt terms overrides globals.
        terms = self.terms | self.prompts[id].terms | input
        resp_type = self.prompts[id].response_type

        prompt = {
            "system": self.system_template.render(terms),
            "prompt": self.templates[id].render(terms),
        }

        resp = self.llm.request(**prompt)
        print(resp, flush=True)

        if resp_type == "text":
            return resp

        if resp_type != "json":
            raise RuntimeError(f"Response type {resp_type} not known")

        try:
            obj = self.parse_json(resp)
        except Exception as e:
            raise RuntimeError("JSON parse fail") from e

        print(obj, flush=True)

        if self.prompts[id].schema:
            try:
                print(self.prompts[id].schema)
                validate(instance=obj, schema=self.prompts[id].schema)
            except Exception as e:
                raise RuntimeError(f"Schema validation fail: {e}") from e

        return obj

View file

@ -1,47 +0,0 @@
def to_relationships(template, text):
    """Fill the relationship-extraction template with the chunk text."""
    prompt = template.format(text=text)
    return prompt
def to_definitions(template, text):
    """Fill the definition-extraction template with the chunk text."""
    rendered = template.format(text=text)
    return rendered
def to_topics(template, text):
    """Fill the topic-extraction template with the chunk text."""
    result = template.format(text=text)
    return result
def to_rows(template, schema, text):
    """Render the row-extraction prompt from the template.

    Fix: the original built a per-field description list (`field_schema`)
    that was never used, and everything after the first `return` (a schema
    f-string, an empty `prompt`, a second `return`) was unreachable dead
    code.  Only the template interpolation below was ever executed, and it
    is now the whole function.  Side effect of the cleanup: `schema` no
    longer needs a `.fields` attribute, which only widens accepted inputs.
    """
    return template.format(schema=schema, text=text)
def get_cypher(kg):
    """Serialise knowledge-graph triples as one pseudo-Cypher edge per line.

    Each triple's .s/.p/.o become "(s)-[p]->(o)"; lines are newline-joined
    and backslashes are replaced with dashes.
    """
    edges = [f"({t.s})-[{t.p}]->({t.o})" for t in kg]
    text = "\n".join(edges)
    return text.replace("\\", "-")
def to_kg_query(template, query, kg):
    """Build a knowledge-graph question prompt from the template."""
    graph_text = get_cypher(kg)
    return template.format(query=query, graph=graph_text)
def to_document_query(template, query, docs):
    """Build a document-question prompt; documents are blank-line separated."""
    joined = "\n\n".join(docs)
    return template.format(query=query, documents=joined)

View file

@ -16,8 +16,7 @@ from .... schema import prompt_request_queue, prompt_response_queue
from .... base import ConsumerProducer
from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships, to_rows
from . prompts import to_kg_query, to_document_query, to_topics
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
module = ".".join(__name__.split(".")[1:-1])
@ -29,6 +28,82 @@ class Processor(ConsumerProducer):
def __init__(self, **params):
prompt_base = {}
# Parsing the prompt information to the prompt configuration
# structure
prompt_arg = params.get("prompt", [])
if prompt_arg:
for p in prompt_arg:
toks = p.split("=", 1)
if len(toks) < 2:
raise RuntimeError(f"Prompt string not well-formed: {p}")
prompt_base[toks[0]] = {
"template": toks[1]
}
prompt_response_type_arg = params.get("prompt_response_type", [])
if prompt_response_type_arg:
for p in prompt_response_type_arg:
toks = p.split("=", 1)
if len(toks) < 2:
raise RuntimeError(f"Response type not well-formed: {p}")
if toks[0] not in prompt_base:
raise RuntimeError(f"Response-type, {toks[0]} not known")
prompt_base[toks[0]]["response_type"] = toks[1]
prompt_schema_arg = params.get("prompt_schema", [])
if prompt_schema_arg:
for p in prompt_schema_arg:
toks = p.split("=", 1)
if len(toks) < 2:
raise RuntimeError(f"Schema arg not well-formed: {p}")
if toks[0] not in prompt_base:
raise RuntimeError(f"Schema, {toks[0]} not known")
try:
prompt_base[toks[0]]["schema"] = json.loads(toks[1])
except:
raise RuntimeError(f"Failed to parse JSON schema: {p}")
prompt_term_arg = params.get("prompt_term", [])
if prompt_term_arg:
for p in prompt_term_arg:
toks = p.split("=", 1)
if len(toks) < 2:
raise RuntimeError(f"Term arg not well-formed: {p}")
if toks[0] not in prompt_base:
raise RuntimeError(f"Term, {toks[0]} not known")
kvtoks = toks[1].split(":", 1)
if len(kvtoks) < 2:
raise RuntimeError(f"Term not well-formed: {toks[1]}")
k, v = kvtoks
if "terms" not in prompt_base[toks[0]]:
prompt_base[toks[0]]["terms"] = {}
prompt_base[toks[0]]["terms"][k] = v
global_terms = {}
global_term_arg = params.get("global_term", [])
if global_term_arg:
for t in global_term_arg:
toks = t.split("=", 1)
if len(toks) < 2:
raise RuntimeError(f"Global term arg not well-formed: {t}")
global_terms[toks[0]] = toks[1]
print(global_terms)
prompts = {
k: Prompt(**v)
for k, v in prompt_base.items()
}
prompt_configuration = PromptConfiguration(
system_template = params.get("system_prompt", ""),
global_terms = global_terms,
prompts = prompts
)
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
@ -64,23 +139,21 @@ class Processor(ConsumerProducer):
pulsar_host = self.pulsar_host
)
self.definition_template = definition_template
self.topic_template = topic_template
self.relationship_template = relationship_template
self.rows_template = rows_template
self.knowledge_query_template = knowledge_query_template
self.document_query_template = document_query_template
# System prompt hack
class Llm:
    """Adapter folding the system prompt into the single-string LLM client.

    The underlying client's request() takes one string; this wrapper joins
    the system and user prompts with a blank line ("system prompt hack").
    """

    def __init__(self, llm):
        self.llm = llm

    def request(self, system, prompt):
        print(system)
        print(prompt, flush=True)
        combined = system + "\n\n" + prompt
        return self.llm.request(combined)
def parse_json(self, text):
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
if json_match:
json_str = json_match.group(1).strip()
else:
# If no delimiters, assume the entire output is JSON
json_str = text.strip()
self.llm = Llm(self.llm)
return json.loads(json_str)
self.manager = PromptManager(
llm = self.llm,
config = prompt_configuration,
)
def handle(self, msg):
@ -90,88 +163,52 @@ class Processor(ConsumerProducer):
id = msg.properties()["id"]
kind = v.kind
print(f"Handling kind {kind}...", flush=True)
if kind == "extract-definitions":
self.handle_extract_definitions(id, v)
return
elif kind == "extract-topics":
self.handle_extract_topics(id, v)
return
elif kind == "extract-relationships":
self.handle_extract_relationships(id, v)
return
elif kind == "extract-rows":
self.handle_extract_rows(id, v)
return
elif kind == "kg-prompt":
self.handle_kg_prompt(id, v)
return
elif kind == "document-prompt":
self.handle_document_prompt(id, v)
return
else:
print("Invalid kind.", flush=True)
return
def handle_extract_definitions(self, id, v):
kind = v.id
try:
prompt = to_definitions(self.definition_template, v.chunk)
print(v.terms)
ans = self.llm.request(prompt)
input = {
k: json.loads(v)
for k, v in v.terms.items()
}
print(f"Handling kind {kind}...", flush=True)
print(input, flush=True)
# Silently ignore JSON parse error
try:
defs = self.parse_json(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
resp = self.manager.invoke(kind, input)
output = []
if isinstance(resp, str):
for defn in defs:
print("Send text response...", flush=True)
print(resp, flush=True)
try:
e = defn["entity"]
d = defn["definition"]
r = PromptResponse(
text=resp,
object=None,
error=None,
)
if e == "": continue
if e is None: continue
if d == "": continue
if d is None: continue
self.producer.send(r, properties={"id": id})
output.append(
Definition(
name=e, definition=d
)
)
return
except:
print("definition fields missing, ignored", flush=True)
else:
print("Send response...", flush=True)
r = PromptResponse(definitions=output, error=None)
self.producer.send(r, properties={"id": id})
print("Send object response...", flush=True)
print(json.dumps(resp, indent=4), flush=True)
print("Done.", flush=True)
r = PromptResponse(
text=None,
object=json.dumps(resp),
error=None,
)
self.producer.send(r, properties={"id": id})
return
except Exception as e:
print(f"Exception: {e}")
@ -188,122 +225,6 @@ class Processor(ConsumerProducer):
self.producer.send(r, properties={"id": id})
def handle_extract_topics(self, id, v):
    """Extract topics from chunk v.chunk via the LLM and publish them.

    Sends a PromptResponse(topics=...) keyed by the request id; on any
    failure an error PromptResponse is sent instead.

    Fix: the bare ``except:`` clauses are narrowed — a bare except also
    traps KeyboardInterrupt/SystemExit.  JSON-parse fallback now catches
    ValueError (json.JSONDecodeError's base class); per-item field access
    catches KeyError/TypeError.
    """
    try:
        prompt = to_topics(self.topic_template, v.chunk)
        ans = self.llm.request(prompt)

        # Silently ignore JSON parse error
        try:
            defs = self.parse_json(ans)
        except ValueError:
            print("JSON parse error, ignored", flush=True)
            defs = []

        output = []
        for defn in defs:
            try:
                e = defn["topic"]
                d = defn["definition"]
                # Skip entries with empty or missing values.
                if e == "" or e is None:
                    continue
                if d == "" or d is None:
                    continue
                output.append(
                    Topic(
                        name=e, definition=d
                    )
                )
            except (KeyError, TypeError):
                print("definition fields missing, ignored", flush=True)

        print("Send response...", flush=True)
        r = PromptResponse(topics=output, error=None)
        self.producer.send(r, properties={"id": id})
        print("Done.", flush=True)

    except Exception as e:
        print(f"Exception: {e}")
        print("Send error response...", flush=True)
        r = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
        )
        self.producer.send(r, properties={"id": id})
def handle_extract_relationships(self, id, v):
try:
prompt = to_relationships(self.relationship_template, v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try:
defs = self.parse_json(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = []
for defn in defs:
try:
s = defn["subject"]
p = defn["predicate"]
o = defn["object"]
o_entity = defn["object-entity"]
if s == "": continue
if s is None: continue
if p == "": continue
if p is None: continue
if o == "": continue
if o is None: continue
if o_entity == "" or o_entity is None:
o_entity = False
output.append(
Relationship(
s = s,
p = p,
o = o,
o_entity = o_entity,
)
)
except Exception as e:
print("relationship fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(relationships=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
@ -320,147 +241,6 @@ class Processor(ConsumerProducer):
self.producer.send(r, properties={"id": id})
def handle_extract_rows(self, id, v):
    # Extract structured rows from the chunk according to v.row_schema and
    # publish them as a PromptResponse(rows=...) keyed by the request id.
    # On any outer failure an error PromptResponse is sent instead.
    try:
        fields = v.row_schema.fields
        prompt = to_rows(self.rows_template, v.row_schema, v.chunk)
        print(prompt)
        ans = self.llm.request(prompt)
        print(ans)
        # Silently ignore JSON parse error
        try:
            objs = self.parse_json(ans)
        except:
            print("JSON parse error, ignored", flush=True)
            objs = []
        output = []
        for obj in objs:
            try:
                row = {}
                for f in fields:
                    if f.name not in obj:
                        # A missing field invalidates the whole row; the
                        # empty dict is the sentinel checked just below.
                        print(f"Object ignored, missing field {f.name}")
                        row = {}
                        break
                    row[f.name] = obj[f.name]
                if row == {}:
                    continue
                output.append(row)
            except Exception as e:
                print("row fields missing, ignored", flush=True)
        for row in output:
            print(row)
        print("Send response...", flush=True)
        r = PromptResponse(rows=output, error=None)
        self.producer.send(r, properties={"id": id})
        print("Done.", flush=True)
    except Exception as e:
        print(f"Exception: {e}")
        print("Send error response...", flush=True)
        r = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
        )
        self.producer.send(r, properties={"id": id})
def handle_kg_prompt(self, id, v):
    """Answer a question against a knowledge graph and publish the reply.

    Renders the knowledge-query template with v.query and v.kg, asks the
    LLM, and sends PromptResponse(answer=...); on failure sends an error
    PromptResponse instead.
    """
    try:
        prompt = to_kg_query(self.knowledge_query_template, v.query, v.kg)
        print(prompt)
        ans = self.llm.request(prompt)
        print(ans)
        print("Send response...", flush=True)
        self.producer.send(
            PromptResponse(answer=ans, error=None),
            properties={"id": id},
        )
        print("Done.", flush=True)
    except Exception as e:
        print(f"Exception: {e}")
        print("Send error response...", flush=True)
        failure = PromptResponse(
            error=Error(type="llm-error", message=str(e)),
            response=None,
        )
        self.producer.send(failure, properties={"id": id})
def handle_document_prompt(self, id, v):
    """Answer a question against a set of documents and publish the reply.

    Renders the document-query template with v.query and v.documents, asks
    the LLM, and sends PromptResponse(answer=...); on failure sends an
    error PromptResponse instead.
    """
    try:
        prompt = to_document_query(
            self.document_query_template, v.query, v.documents
        )
        print(prompt)
        ans = self.llm.request(prompt)
        print(ans)
        print("Send response...", flush=True)
        self.producer.send(
            PromptResponse(answer=ans, error=None),
            properties={"id": id},
        )
        print("Done.", flush=True)
    except Exception as e:
        print(f"Exception: {e}")
        print("Send error response...", flush=True)
        failure = PromptResponse(
            error=Error(type="llm-error", message=str(e)),
            response=None,
        )
        self.producer.send(failure, properties={"id": id})
@staticmethod
def add_args(parser):
@ -482,39 +262,33 @@ class Processor(ConsumerProducer):
)
parser.add_argument(
'--definition-template',
required=True,
help=f'Definition extraction template',
'--prompt', nargs='*',
help=f'Prompt template form id=template',
)
parser.add_argument(
'--topic-template',
required=True,
help=f'Topic extraction template',
'--prompt-response-type', nargs='*',
help=f'Prompt response type, form id=json|text',
)
parser.add_argument(
'--rows-template',
required=True,
help=f'Rows extraction template',
'--prompt-term', nargs='*',
help=f'Prompt response type, form id=key:value',
)
parser.add_argument(
'--relationship-template',
required=True,
help=f'Relationship extraction template',
'--prompt-schema', nargs='*',
help=f'Prompt response schema, form id=schema',
)
parser.add_argument(
'--knowledge-query-template',
required=True,
help=f'Knowledge query template',
'--system-prompt',
help=f'System prompt template',
)
parser.add_argument(
'--document-query-template',
required=True,
help=f'Document query template',
'--global-term', nargs='+',
help=f'Global term, form key:value'
)
def run():