mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-16 18:35:14 +02:00
Update to enable knowledge extraction using the agent framework (#439)
* Implement KG extraction agent (kg-extract-agent) * Using ReAct framework (agent-manager-react) * ReAct manager had an issue when emitting JSON, which conflicts which ReAct manager's own JSON messages, so refactored ReAct manager to use traditional ReAct messages, non-JSON structure. * Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework.
This commit is contained in:
parent
1fe4ed5226
commit
d83e4e3d59
30 changed files with 3192 additions and 799 deletions
3
trustgraph-flow/trustgraph/template/__init__.py
Normal file
3
trustgraph-flow/trustgraph/template/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from .prompt_manager import *
|
||||
|
||||
137
trustgraph-flow/trustgraph/template/prompt_manager.py
Normal file
137
trustgraph-flow/trustgraph/template/prompt_manager.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
|
||||
import ibis
|
||||
import json
|
||||
from jsonschema import validate
|
||||
import re
|
||||
|
||||
class PromptConfiguration:
|
||||
def __init__(self, system_template, global_terms={}, prompts={}):
|
||||
self.system_template = system_template
|
||||
self.global_terms = global_terms
|
||||
self.prompts = prompts
|
||||
|
||||
class Prompt:
|
||||
def __init__(self, template, response_type = "text", terms=None, schema=None):
|
||||
self.template = template
|
||||
self.response_type = response_type
|
||||
self.terms = terms
|
||||
self.schema = schema
|
||||
|
||||
class PromptManager:
|
||||
|
||||
def __init__(self):
|
||||
|
||||
self.load_config({})
|
||||
|
||||
def load_config(self, config):
|
||||
|
||||
try:
|
||||
system = json.loads(config["system"])
|
||||
except:
|
||||
system = "Be helpful."
|
||||
|
||||
try:
|
||||
ix = json.loads(config["template-index"])
|
||||
except:
|
||||
ix = []
|
||||
|
||||
prompts = {}
|
||||
|
||||
for k in ix:
|
||||
|
||||
pc = config[f"template.{k}"]
|
||||
data = json.loads(pc)
|
||||
|
||||
prompt = data.get("prompt")
|
||||
rtype = data.get("response-type", "text")
|
||||
schema = data.get("schema", None)
|
||||
|
||||
prompts[k] = Prompt(
|
||||
template = prompt,
|
||||
response_type = rtype,
|
||||
schema = schema,
|
||||
terms = {}
|
||||
)
|
||||
|
||||
self.config = PromptConfiguration(
|
||||
system,
|
||||
{},
|
||||
prompts
|
||||
)
|
||||
|
||||
self.terms = self.config.global_terms
|
||||
self.prompts = self.config.prompts
|
||||
|
||||
try:
|
||||
self.system_template = ibis.Template(self.config.system_template)
|
||||
except:
|
||||
raise RuntimeError("Error in system template")
|
||||
|
||||
self.templates = {}
|
||||
for k, v in self.prompts.items():
|
||||
try:
|
||||
self.templates[k] = ibis.Template(v.template)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in template: {k}: {e}")
|
||||
|
||||
if v.terms is None:
|
||||
v.terms = {}
|
||||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
# If no delimiters, assume the entire output is JSON
|
||||
json_str = text.strip()
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
def render(self, id, input):
|
||||
|
||||
if id not in self.prompts:
|
||||
raise RuntimeError("ID invalid")
|
||||
|
||||
terms = self.terms | self.prompts[id].terms | input
|
||||
|
||||
resp_type = self.prompts[id].response_type
|
||||
|
||||
return self.templates[id].render(terms)
|
||||
|
||||
async def invoke(self, id, input, llm):
|
||||
|
||||
print("Invoke...", flush=True)
|
||||
|
||||
terms = self.terms | self.prompts[id].terms | input
|
||||
|
||||
resp_type = self.prompts[id].response_type
|
||||
|
||||
prompt = {
|
||||
"system": self.system_template.render(terms),
|
||||
"prompt": self.render(id, input)
|
||||
}
|
||||
|
||||
resp = await llm(**prompt)
|
||||
|
||||
if resp_type == "text":
|
||||
return resp
|
||||
|
||||
if resp_type != "json":
|
||||
raise RuntimeError(f"Response type {resp_type} not known")
|
||||
|
||||
try:
|
||||
obj = self.parse_json(resp)
|
||||
except:
|
||||
print("Parse fail:", resp, flush=True)
|
||||
raise RuntimeError("JSON parse fail")
|
||||
|
||||
if self.prompts[id].schema:
|
||||
try:
|
||||
validate(instance=obj, schema=self.prompts[id].schema)
|
||||
print("Validated", flush=True)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Schema validation fail: {e}")
|
||||
|
||||
return obj
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue