mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 08:56:21 +02:00
Update to enable knowledge extraction using the agent framework (#439)
* Implement KG extraction agent (kg-extract-agent) * Using ReAct framework (agent-manager-react) * ReAct manager had an issue when emitting JSON, which conflicts which ReAct manager's own JSON messages, so refactored ReAct manager to use traditional ReAct messages, non-JSON structure. * Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework.
This commit is contained in:
parent
1fe4ed5226
commit
d83e4e3d59
30 changed files with 3192 additions and 799 deletions
6
trustgraph-flow/scripts/kg-extract-agent
Executable file
6
trustgraph-flow/scripts/kg-extract-agent
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.extract.kg.agent import run
|
||||
|
||||
run()
|
||||
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.prompt.generic import run
|
||||
|
||||
run()
|
||||
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.prompt.template import run
|
||||
from trustgraph.prompt.template import run
|
||||
|
||||
run()
|
||||
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ setuptools.setup(
|
|||
"scripts/graph-rag",
|
||||
"scripts/kg-extract-definitions",
|
||||
"scripts/kg-extract-relationships",
|
||||
"scripts/kg-extract-topics",
|
||||
"scripts/kg-extract-agent",
|
||||
"scripts/kg-store",
|
||||
"scripts/kg-manager",
|
||||
"scripts/librarian",
|
||||
|
|
@ -106,7 +106,6 @@ setuptools.setup(
|
|||
"scripts/oe-write-milvus",
|
||||
"scripts/pdf-decoder",
|
||||
"scripts/pdf-ocr-mistral",
|
||||
"scripts/prompt-generic",
|
||||
"scripts/prompt-template",
|
||||
"scripts/rows-write-cassandra",
|
||||
"scripts/run-processing",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
|
||||
from . types import Action, Final
|
||||
|
||||
|
|
@ -12,6 +13,155 @@ class AgentManager:
|
|||
self.tools = tools
|
||||
self.additional_context = additional_context
|
||||
|
||||
def parse_react_response(self, text):
|
||||
"""Parse text-based ReAct response format.
|
||||
|
||||
Expected format:
|
||||
Thought: [reasoning about what to do next]
|
||||
Action: [tool_name]
|
||||
Args: {
|
||||
"param": "value"
|
||||
}
|
||||
|
||||
OR
|
||||
|
||||
Thought: [reasoning about the final answer]
|
||||
Final Answer: [the answer]
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"Expected string response, got {type(text)}")
|
||||
|
||||
# Remove any markdown code blocks that might wrap the response
|
||||
text = re.sub(r'^```[^\n]*\n', '', text.strip())
|
||||
text = re.sub(r'\n```$', '', text.strip())
|
||||
|
||||
lines = text.strip().split('\n')
|
||||
|
||||
thought = None
|
||||
action = None
|
||||
args = None
|
||||
final_answer = None
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i].strip()
|
||||
|
||||
# Parse Thought
|
||||
if line.startswith("Thought:"):
|
||||
thought = line[8:].strip()
|
||||
# Handle multi-line thoughts
|
||||
i += 1
|
||||
while i < len(lines):
|
||||
next_line = lines[i].strip()
|
||||
if next_line.startswith(("Action:", "Final Answer:", "Args:")):
|
||||
break
|
||||
thought += " " + next_line
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Parse Final Answer
|
||||
if line.startswith("Final Answer:"):
|
||||
final_answer = line[13:].strip()
|
||||
# Handle multi-line final answers (including JSON)
|
||||
i += 1
|
||||
|
||||
# Check if the answer might be JSON
|
||||
if final_answer.startswith('{') or (i < len(lines) and lines[i].strip().startswith('{')):
|
||||
# Collect potential JSON answer
|
||||
json_text = final_answer if final_answer.startswith('{') else ""
|
||||
brace_count = json_text.count('{') - json_text.count('}')
|
||||
|
||||
while i < len(lines) and (brace_count > 0 or not json_text):
|
||||
current_line = lines[i].strip()
|
||||
if current_line.startswith(("Thought:", "Action:")) and brace_count == 0:
|
||||
break
|
||||
json_text += ("\n" if json_text else "") + current_line
|
||||
brace_count += current_line.count('{') - current_line.count('}')
|
||||
i += 1
|
||||
|
||||
# Try to parse as JSON
|
||||
# try:
|
||||
# final_answer = json.loads(json_text)
|
||||
# except json.JSONDecodeError:
|
||||
# # Not valid JSON, treat as regular text
|
||||
# final_answer = json_text
|
||||
final_answer = json_text
|
||||
else:
|
||||
# Regular text answer
|
||||
while i < len(lines):
|
||||
next_line = lines[i].strip()
|
||||
if next_line.startswith(("Thought:", "Action:")):
|
||||
break
|
||||
final_answer += " " + next_line
|
||||
i += 1
|
||||
|
||||
# If we have a final answer, return Final object
|
||||
return Final(
|
||||
thought=thought or "",
|
||||
final=final_answer
|
||||
)
|
||||
|
||||
# Parse Action
|
||||
if line.startswith("Action:"):
|
||||
action = line[7:].strip()
|
||||
|
||||
# Parse Args
|
||||
if line.startswith("Args:"):
|
||||
# Check if JSON starts on the same line
|
||||
args_on_same_line = line[5:].strip()
|
||||
if args_on_same_line:
|
||||
args_text = args_on_same_line
|
||||
brace_count = args_on_same_line.count('{') - args_on_same_line.count('}')
|
||||
else:
|
||||
args_text = ""
|
||||
brace_count = 0
|
||||
|
||||
# Collect all lines that form the JSON arguments
|
||||
i += 1
|
||||
started = bool(args_on_same_line and '{' in args_on_same_line)
|
||||
|
||||
while i < len(lines) and (not started or brace_count > 0):
|
||||
current_line = lines[i]
|
||||
args_text += ("\n" if args_text else "") + current_line
|
||||
|
||||
# Count braces to determine when JSON is complete
|
||||
for char in current_line:
|
||||
if char == '{':
|
||||
brace_count += 1
|
||||
started = True
|
||||
elif char == '}':
|
||||
brace_count -= 1
|
||||
|
||||
# If we've started and braces are balanced, we're done
|
||||
if started and brace_count == 0:
|
||||
break
|
||||
|
||||
i += 1
|
||||
|
||||
# Parse the JSON arguments
|
||||
try:
|
||||
args = json.loads(args_text.strip())
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse JSON arguments: {args_text}")
|
||||
raise ValueError(f"Invalid JSON in Args: {e}")
|
||||
|
||||
i += 1
|
||||
|
||||
# If we have an action, return Action object
|
||||
if action:
|
||||
return Action(
|
||||
thought=thought or "",
|
||||
name=action,
|
||||
arguments=args or {},
|
||||
observation=""
|
||||
)
|
||||
|
||||
# If we only have a thought but no action or final answer
|
||||
if thought and not action and not final_answer:
|
||||
raise ValueError(f"Response has thought but no action or final answer: {text}")
|
||||
|
||||
raise ValueError(f"Could not parse response: {text}")
|
||||
|
||||
async def reason(self, question, history, context):
|
||||
|
||||
print(f"calling reason: {question}", flush=True)
|
||||
|
|
@ -62,31 +212,23 @@ class AgentManager:
|
|||
|
||||
logger.info(f"prompt: {variables}")
|
||||
|
||||
obj = await context("prompt-request").agent_react(variables)
|
||||
# Get text response from prompt service
|
||||
response_text = await context("prompt-request").agent_react(variables)
|
||||
|
||||
print(json.dumps(obj, indent=4), flush=True)
|
||||
print(f"Response text:\n{response_text}", flush=True)
|
||||
|
||||
logger.info(f"response: {obj}")
|
||||
logger.info(f"response: {response_text}")
|
||||
|
||||
if obj.get("final-answer"):
|
||||
|
||||
a = Final(
|
||||
thought = obj.get("thought"),
|
||||
final = obj.get("final-answer"),
|
||||
)
|
||||
|
||||
return a
|
||||
|
||||
else:
|
||||
|
||||
a = Action(
|
||||
thought = obj.get("thought"),
|
||||
name = obj.get("action"),
|
||||
arguments = obj.get("arguments"),
|
||||
observation = ""
|
||||
)
|
||||
|
||||
return a
|
||||
# Parse the text response
|
||||
try:
|
||||
result = self.parse_react_response(response_text)
|
||||
logger.info(f"Parsed result: {result}")
|
||||
return result
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to parse response: {e}")
|
||||
# Try to provide a helpful error message
|
||||
logger.error(f"Response was: {response_text}")
|
||||
raise RuntimeError(f"Failed to parse agent response: {e}")
|
||||
|
||||
async def react(self, question, history, think, observe, context):
|
||||
|
||||
|
|
@ -120,7 +262,11 @@ class AgentManager:
|
|||
**act.arguments
|
||||
)
|
||||
|
||||
resp = resp.strip()
|
||||
if isinstance(resp, str):
|
||||
resp = resp.strip()
|
||||
else:
|
||||
resp = str(resp)
|
||||
resp = resp.strip()
|
||||
|
||||
logger.info(f"resp: {resp}")
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ import json
|
|||
import re
|
||||
import sys
|
||||
import functools
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
||||
from ... base import GraphRagClientSpec, ToolClientSpec
|
||||
|
|
@ -221,6 +225,11 @@ class Processor(AgentService):
|
|||
|
||||
print("Send final response...", flush=True)
|
||||
|
||||
if isinstance(act.final, str):
|
||||
f = act.final
|
||||
else:
|
||||
f = json.dumps(act.final)
|
||||
|
||||
r = AgentResponse(
|
||||
answer=act.final,
|
||||
error=None,
|
||||
|
|
@ -292,6 +301,5 @@ class Processor(AgentService):
|
|||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
1
trustgraph-flow/trustgraph/extract/kg/agent/__init__.py
Normal file
1
trustgraph-flow/trustgraph/extract/kg/agent/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .extract import *
|
||||
4
trustgraph-flow/trustgraph/extract/kg/agent/__main__.py
Normal file
4
trustgraph-flow/trustgraph/extract/kg/agent/__main__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from .extract import Processor
|
||||
|
||||
if __name__ == "__main__":
|
||||
Processor.run()
|
||||
336
trustgraph-flow/trustgraph/extract/kg/agent/extract.py
Normal file
336
trustgraph-flow/trustgraph/extract/kg/agent/extract.py
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
import re
|
||||
import json
|
||||
import urllib.parse
|
||||
|
||||
from ....schema import Chunk, Triple, Triples, Metadata, Value
|
||||
from ....schema import EntityContext, EntityContexts
|
||||
|
||||
from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, SUBJECT_OF, DEFINITION
|
||||
|
||||
from ....base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ....base import AgentClientSpec
|
||||
|
||||
from ....template import PromptManager
|
||||
|
||||
default_ident = "kg-extract-agent"
|
||||
default_concurrency = 1
|
||||
default_template_id = "agent-kg-extract"
|
||||
default_config_type = "prompt"
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
template_id = params.get("template-id", default_template_id)
|
||||
config_key = params.get("config-type", default_config_type)
|
||||
|
||||
super().__init__(**params | {
|
||||
"id": id,
|
||||
"template-id": template_id,
|
||||
"config-type": config_key,
|
||||
"concurrency": concurrency,
|
||||
})
|
||||
|
||||
self.concurrency = concurrency
|
||||
self.template_id = template_id
|
||||
self.config_key = config_key
|
||||
|
||||
self.register_config_handler(self.on_prompt_config)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = Chunk,
|
||||
handler = self.on_message,
|
||||
concurrency = self.concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
AgentClientSpec(
|
||||
request_name = "agent-request",
|
||||
response_name = "agent-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name="triples",
|
||||
schema=Triples,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name="entity-contexts",
|
||||
schema=EntityContexts,
|
||||
)
|
||||
)
|
||||
|
||||
# Null configuration, should reload quickly
|
||||
self.manager = PromptManager()
|
||||
|
||||
async def on_prompt_config(self, config, version):
|
||||
|
||||
print("Loading configuration version", version, flush=True)
|
||||
|
||||
if self.config_key not in config:
|
||||
print(f"No key {self.config_key} in config", flush=True)
|
||||
return
|
||||
|
||||
config = config[self.config_key]
|
||||
|
||||
try:
|
||||
|
||||
self.manager.load_config(config)
|
||||
|
||||
print("Prompt configuration reloaded.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print("Exception:", e, flush=True)
|
||||
print("Configuration reload failed", flush=True)
|
||||
|
||||
def to_uri(self, text):
|
||||
return TRUSTGRAPH_ENTITIES + urllib.parse.quote(text)
|
||||
|
||||
async def emit_triples(self, pub, metadata, triples):
|
||||
tpls = Triples(
|
||||
metadata = Metadata(
|
||||
id = metadata.id,
|
||||
metadata = [],
|
||||
user = metadata.user,
|
||||
collection = metadata.collection,
|
||||
),
|
||||
triples = triples,
|
||||
)
|
||||
|
||||
await pub.send(tpls)
|
||||
|
||||
async def emit_entity_contexts(self, pub, metadata, entity_contexts):
|
||||
ecs = EntityContexts(
|
||||
metadata = Metadata(
|
||||
id = metadata.id,
|
||||
metadata = [],
|
||||
user = metadata.user,
|
||||
collection = metadata.collection,
|
||||
),
|
||||
entities = entity_contexts,
|
||||
)
|
||||
|
||||
await pub.send(ecs)
|
||||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
# If no delimiters, assume the entire output is JSON
|
||||
json_str = text.strip()
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Extract chunk text
|
||||
chunk_text = v.chunk.decode('utf-8')
|
||||
|
||||
print("Got chunk", flush=True)
|
||||
|
||||
prompt = self.manager.render(
|
||||
self.template_id,
|
||||
{
|
||||
"text": chunk_text
|
||||
}
|
||||
)
|
||||
|
||||
print("Prompt:", prompt, flush=True)
|
||||
|
||||
async def handle(response):
|
||||
|
||||
print("Response:", response, flush=True)
|
||||
|
||||
if response.error is not None:
|
||||
if response.error.message:
|
||||
raise RuntimeError(str(response.error.message))
|
||||
else:
|
||||
raise RuntimeError(str(response.error))
|
||||
|
||||
if response.answer is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# Send to agent API
|
||||
agent_response = await flow("agent-request").invoke(
|
||||
recipient = handle,
|
||||
question = prompt
|
||||
)
|
||||
|
||||
# Parse JSON response
|
||||
try:
|
||||
extraction_data = self.parse_json(agent_response)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON response from agent: {e}")
|
||||
|
||||
# Process extraction data
|
||||
triples, entity_contexts = self.process_extraction_data(
|
||||
extraction_data, v.metadata
|
||||
)
|
||||
|
||||
# Put document metadata into triples
|
||||
for t in v.metadata.metadata:
|
||||
triples.append(t)
|
||||
|
||||
# Emit outputs
|
||||
if triples:
|
||||
await self.emit_triples(flow("triples"), v.metadata, triples)
|
||||
|
||||
if entity_contexts:
|
||||
await self.emit_entity_contexts(
|
||||
flow("entity-contexts"),
|
||||
v.metadata,
|
||||
entity_contexts
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing chunk: {e}", flush=True)
|
||||
raise
|
||||
|
||||
def process_extraction_data(self, data, metadata):
|
||||
"""Process combined extraction data to generate triples and entity contexts"""
|
||||
triples = []
|
||||
entity_contexts = []
|
||||
|
||||
# Process definitions
|
||||
for defn in data.get("definitions", []):
|
||||
|
||||
entity_uri = self.to_uri(defn["entity"])
|
||||
|
||||
# Add entity label
|
||||
triples.append(Triple(
|
||||
s = Value(value=entity_uri, is_uri=True),
|
||||
p = Value(value=RDF_LABEL, is_uri=True),
|
||||
o = Value(value=defn["entity"], is_uri=False),
|
||||
))
|
||||
|
||||
# Add definition
|
||||
triples.append(Triple(
|
||||
s = Value(value=entity_uri, is_uri=True),
|
||||
p = Value(value=DEFINITION, is_uri=True),
|
||||
o = Value(value=defn["definition"], is_uri=False),
|
||||
))
|
||||
|
||||
# Add subject-of relationship to document
|
||||
if metadata.id:
|
||||
triples.append(Triple(
|
||||
s = Value(value=entity_uri, is_uri=True),
|
||||
p = Value(value=SUBJECT_OF, is_uri=True),
|
||||
o = Value(value=metadata.id, is_uri=True),
|
||||
))
|
||||
|
||||
# Create entity context for embeddings
|
||||
entity_contexts.append(EntityContext(
|
||||
entity=Value(value=entity_uri, is_uri=True),
|
||||
context=defn["definition"]
|
||||
))
|
||||
|
||||
# Process relationships
|
||||
for rel in data.get("relationships", []):
|
||||
|
||||
subject_uri = self.to_uri(rel["subject"])
|
||||
predicate_uri = self.to_uri(rel["predicate"])
|
||||
|
||||
subject_value = Value(value=subject_uri, is_uri=True)
|
||||
predicate_value = Value(value=predicate_uri, is_uri=True)
|
||||
if data.get("object-entity", False):
|
||||
object_value = Value(value=predicate_uri, is_uri=True)
|
||||
else:
|
||||
object_value = Value(value=predicate_uri, is_uri=False)
|
||||
|
||||
# Add subject and predicate labels
|
||||
triples.append(Triple(
|
||||
s = subject_value,
|
||||
p = Value(value=RDF_LABEL, is_uri=True),
|
||||
o = Value(value=rel["subject"], is_uri=False),
|
||||
))
|
||||
|
||||
triples.append(Triple(
|
||||
s = predicate_value,
|
||||
p = Value(value=RDF_LABEL, is_uri=True),
|
||||
o = Value(value=rel["predicate"], is_uri=False),
|
||||
))
|
||||
|
||||
# Handle object (entity vs literal)
|
||||
if rel.get("object-entity", True):
|
||||
triples.append(Triple(
|
||||
s = object_value,
|
||||
p = Value(value=RDF_LABEL, is_uri=True),
|
||||
o = Value(value=rel["object"], is_uri=True),
|
||||
))
|
||||
|
||||
# Add the main relationship triple
|
||||
triples.append(Triple(
|
||||
s = subject_value,
|
||||
p = predicate_value,
|
||||
o = object_value
|
||||
))
|
||||
|
||||
# Add subject-of relationships to document
|
||||
if metadata.id:
|
||||
triples.append(Triple(
|
||||
s = subject_value,
|
||||
p = Value(value=SUBJECT_OF, is_uri=True),
|
||||
o = Value(value=metadata.id, is_uri=True),
|
||||
))
|
||||
|
||||
triples.append(Triple(
|
||||
s = predicate_value,
|
||||
p = Value(value=SUBJECT_OF, is_uri=True),
|
||||
o = Value(value=metadata.id, is_uri=True),
|
||||
))
|
||||
|
||||
if rel.get("object-entity", True):
|
||||
triples.append(Triple(
|
||||
s = object_value,
|
||||
p = Value(value=SUBJECT_OF, is_uri=True),
|
||||
o = Value(value=metadata.id, is_uri=True),
|
||||
))
|
||||
|
||||
return triples, entity_contexts
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--template-id",
|
||||
type=str,
|
||||
default=default_template_id,
|
||||
help="Template ID to use for agent extraction"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--config-type',
|
||||
default="prompt",
|
||||
help=f'Configuration key for prompts (default: prompt)',
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -1,176 +0,0 @@
|
|||
|
||||
def to_relationships(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.
|
||||
|
||||
Read the provided text. You will model the text as an information network for a RDF knowledge graph in JSON.
|
||||
|
||||
Information Network Rules:
|
||||
- An information network has subjects connected by predicates to objects.
|
||||
- A subject is a named-entity or a conceptual topic.
|
||||
- One subject can have many predicates and objects.
|
||||
- An object is a property or attribute of a subject.
|
||||
- A subject can be connected by a predicate to another subject.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Obey the information network rules.
|
||||
- Do not return special characters.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of JSON objects with keys "subject", "predicate", "object", and "object-entity".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"subject": string, "predicate": string, "object": string, "object-entity": boolean}}]
|
||||
```
|
||||
|
||||
- The key "object-entity" is TRUE only if the "object" is a subject.
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_topics(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify topics and their definitions in JSON.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Do not respond with special characters.
|
||||
- Return only topics that are concepts and unique to the provided text.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of objects with keys "topic" and "definition".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"topic": string, "definition": string}}]
|
||||
```
|
||||
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_definitions(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify entities and their definitions in JSON.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Do not respond with special characters.
|
||||
- Return only entities that are named-entities such as: people, organizations, physical objects, locations, animals, products, commodotities, or substances.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of objects with keys "entity" and "definition".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"entity": string, "definition": string}}]
|
||||
```
|
||||
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_rows(schema, text):
|
||||
|
||||
field_schema = [
|
||||
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
|
||||
for f in schema.fields
|
||||
]
|
||||
|
||||
field_schema = "\n".join(field_schema)
|
||||
|
||||
schema = f"""Object name: {schema.name}
|
||||
Description: {schema.description}
|
||||
|
||||
Fields:
|
||||
{field_schema}"""
|
||||
|
||||
prompt = f"""<instructions>
|
||||
Study the following text and derive objects which match the schema provided.
|
||||
|
||||
You must output an array of JSON objects for each object you discover
|
||||
which matches the schema. For each object, output a JSON object whose fields
|
||||
carry the name field specified in the schema.
|
||||
</instructions>
|
||||
|
||||
<schema>
|
||||
{schema}
|
||||
</schema>
|
||||
|
||||
<text>
|
||||
{text}
|
||||
</text>
|
||||
|
||||
<requirements>
|
||||
You will respond only with raw JSON format data. Do not provide
|
||||
explanations. Do not add markdown formatting or headers or prefixes.
|
||||
</requirements>"""
|
||||
|
||||
return prompt
|
||||
|
||||
def get_cypher(kg):
|
||||
|
||||
sg2 = []
|
||||
|
||||
for f in kg:
|
||||
|
||||
print(f)
|
||||
|
||||
sg2.append(f"({f.s})-[{f.p}]->({f.o})")
|
||||
|
||||
print(sg2)
|
||||
|
||||
kg = "\n".join(sg2)
|
||||
kg = kg.replace("\\", "-")
|
||||
|
||||
return kg
|
||||
|
||||
def to_kg_query(query, kg):
|
||||
|
||||
cypher = get_cypher(kg)
|
||||
|
||||
prompt=f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
|
||||
|
||||
Here's the knowledge statements:
|
||||
{cypher}
|
||||
|
||||
Use only the provided knowledge statements to respond to the following:
|
||||
{query}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_document_query(query, documents):
|
||||
|
||||
documents = "\n\n".join(documents)
|
||||
|
||||
prompt=f"""Study the following context. Use only the information provided in the context in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
|
||||
|
||||
Here is the context:
|
||||
{documents}
|
||||
|
||||
Use only the provided knowledge statements to respond to the following:
|
||||
{query}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
|
@ -1,485 +0,0 @@
|
|||
"""
|
||||
Language service abstracts prompt engineering from LLM.
|
||||
"""
|
||||
|
||||
#
|
||||
# FIXME: This module is broken, it doesn't conform to the prompt API change
|
||||
# made in 0.14, nor the prompt template support.
|
||||
#
|
||||
# It could be made to conform by using prompt-template as a starting
|
||||
# point, and hard-coding all the information.
|
||||
#
|
||||
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .... schema import Definition, Relationship, Triple
|
||||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... schema import prompt_request_queue, prompt_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships, to_topics
|
||||
from . prompts import to_kg_query, to_document_query, to_rows
|
||||
|
||||
module = "prompt"
|
||||
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
tc_request_queue = params.get(
|
||||
"text_completion_request_queue", text_completion_request_queue
|
||||
)
|
||||
tc_response_queue = params.get(
|
||||
"text_completion_response_queue", text_completion_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": PromptRequest,
|
||||
"output_schema": PromptResponse,
|
||||
"text_completion_request_queue": tc_request_queue,
|
||||
"text_completion_response_queue": tc_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.llm = LlmClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=tc_request_queue,
|
||||
output_queue=tc_response_queue,
|
||||
pulsar_host = self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
)
|
||||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
# If no delimiters, assume the entire output is JSON
|
||||
json_str = text.strip()
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
kind = v.kind
|
||||
|
||||
print(f"Handling kind {kind}...", flush=True)
|
||||
|
||||
if kind == "extract-definitions":
|
||||
|
||||
await self.handle_extract_definitions(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-topics":
|
||||
|
||||
await self.handle_extract_topics(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-relationships":
|
||||
|
||||
await self.handle_extract_relationships(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-rows":
|
||||
|
||||
await self.handle_extract_rows(id, v)
|
||||
return
|
||||
|
||||
elif kind == "kg-prompt":
|
||||
|
||||
await self.handle_kg_prompt(id, v)
|
||||
return
|
||||
|
||||
elif kind == "document-prompt":
|
||||
|
||||
await self.handle_document_prompt(id, v)
|
||||
return
|
||||
|
||||
else:
|
||||
|
||||
print("Invalid kind.", flush=True)
|
||||
return
|
||||
|
||||
async def handle_extract_definitions(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_definitions(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["entity"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Definition(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(definitions=output, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
async def handle_extract_topics(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_topics(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["topic"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Topic(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(topics=output, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
async def handle_extract_relationships(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_relationships(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
|
||||
s = defn["subject"]
|
||||
p = defn["predicate"]
|
||||
o = defn["object"]
|
||||
o_entity = defn["object-entity"]
|
||||
|
||||
if s == "": continue
|
||||
if s is None: continue
|
||||
|
||||
if p == "": continue
|
||||
if p is None: continue
|
||||
|
||||
if o == "": continue
|
||||
if o is None: continue
|
||||
|
||||
if o_entity == "" or o_entity is None:
|
||||
o_entity = False
|
||||
|
||||
output.append(
|
||||
Relationship(
|
||||
s = s,
|
||||
p = p,
|
||||
o = o,
|
||||
o_entity = o_entity,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("relationship fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(relationships=output, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
async def handle_extract_rows(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
fields = v.row_schema.fields
|
||||
|
||||
prompt = to_rows(v.row_schema, v.chunk)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
objs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
objs = []
|
||||
|
||||
output = []
|
||||
|
||||
for obj in objs:
|
||||
|
||||
try:
|
||||
|
||||
row = {}
|
||||
|
||||
for f in fields:
|
||||
|
||||
if f.name not in obj:
|
||||
print(f"Object ignored, missing field {f.name}")
|
||||
row = {}
|
||||
break
|
||||
|
||||
row[f.name] = obj[f.name]
|
||||
|
||||
if row == {}:
|
||||
continue
|
||||
|
||||
output.append(row)
|
||||
|
||||
except Exception as e:
|
||||
print("row fields missing, ignored", flush=True)
|
||||
|
||||
for row in output:
|
||||
print(row)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(rows=output, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
async def handle_kg_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_kg_query(v.query, v.kg)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
async def handle_document_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_document_query(v.query, v.documents)
|
||||
|
||||
print("prompt")
|
||||
print(prompt)
|
||||
|
||||
print("Call LLM...")
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-request-queue',
|
||||
default=text_completion_request_queue,
|
||||
help=f'Text completion request queue (default: {text_completion_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-response-queue',
|
||||
default=text_completion_response_queue,
|
||||
help=f'Text completion response queue (default: {text_completion_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
raise RuntimeError("NOT IMPLEMENTED")
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -7,15 +7,15 @@ import asyncio
|
|||
import json
|
||||
import re
|
||||
|
||||
from .... schema import Definition, Relationship, Triple
|
||||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from ...schema import Definition, Relationship, Triple
|
||||
from ...schema import Topic
|
||||
from ...schema import PromptRequest, PromptResponse, Error
|
||||
from ...schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
from .... base import FlowProcessor
|
||||
from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
|
||||
from ...base import FlowProcessor
|
||||
from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
|
||||
|
||||
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
|
||||
from ...template import PromptManager
|
||||
|
||||
default_ident = "prompt"
|
||||
default_concurrency = 1
|
||||
|
|
@ -33,6 +33,7 @@ class Processor(FlowProcessor):
|
|||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"config-type": self.config_key,
|
||||
"concurrency": concurrency,
|
||||
}
|
||||
)
|
||||
|
|
@ -63,9 +64,7 @@ class Processor(FlowProcessor):
|
|||
self.register_config_handler(self.on_prompt_config)
|
||||
|
||||
# Null configuration, should reload quickly
|
||||
self.manager = PromptManager(
|
||||
config = PromptConfiguration("", {}, {})
|
||||
)
|
||||
self.manager = PromptManager()
|
||||
|
||||
async def on_prompt_config(self, config, version):
|
||||
|
||||
|
|
@ -79,34 +78,7 @@ class Processor(FlowProcessor):
|
|||
|
||||
try:
|
||||
|
||||
system = json.loads(config["system"])
|
||||
ix = json.loads(config["template-index"])
|
||||
|
||||
prompts = {}
|
||||
|
||||
for k in ix:
|
||||
|
||||
pc = config[f"template.{k}"]
|
||||
data = json.loads(pc)
|
||||
|
||||
prompt = data.get("prompt")
|
||||
rtype = data.get("response-type", "text")
|
||||
schema = data.get("schema", None)
|
||||
|
||||
prompts[k] = Prompt(
|
||||
template = prompt,
|
||||
response_type = rtype,
|
||||
schema = schema,
|
||||
terms = {}
|
||||
)
|
||||
|
||||
self.manager = PromptManager(
|
||||
PromptConfiguration(
|
||||
system,
|
||||
{},
|
||||
prompts
|
||||
)
|
||||
)
|
||||
self.manager.load_config(config)
|
||||
|
||||
print("Prompt configuration reloaded.", flush=True)
|
||||
|
||||
|
|
@ -230,14 +202,14 @@ class Processor(FlowProcessor):
|
|||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--config-type',
|
||||
default="prompt",
|
||||
help=f'Configuration key for prompts (default: prompt)',
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
3
trustgraph-flow/trustgraph/template/__init__.py
Normal file
3
trustgraph-flow/trustgraph/template/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from .prompt_manager import *
|
||||
|
||||
|
|
@ -19,14 +19,51 @@ class Prompt:
|
|||
|
||||
class PromptManager:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.terms = config.global_terms
|
||||
def __init__(self):
|
||||
|
||||
self.prompts = config.prompts
|
||||
self.load_config({})
|
||||
|
||||
def load_config(self, config):
|
||||
|
||||
try:
|
||||
self.system_template = ibis.Template(config.system_template)
|
||||
system = json.loads(config["system"])
|
||||
except:
|
||||
system = "Be helpful."
|
||||
|
||||
try:
|
||||
ix = json.loads(config["template-index"])
|
||||
except:
|
||||
ix = []
|
||||
|
||||
prompts = {}
|
||||
|
||||
for k in ix:
|
||||
|
||||
pc = config[f"template.{k}"]
|
||||
data = json.loads(pc)
|
||||
|
||||
prompt = data.get("prompt")
|
||||
rtype = data.get("response-type", "text")
|
||||
schema = data.get("schema", None)
|
||||
|
||||
prompts[k] = Prompt(
|
||||
template = prompt,
|
||||
response_type = rtype,
|
||||
schema = schema,
|
||||
terms = {}
|
||||
)
|
||||
|
||||
self.config = PromptConfiguration(
|
||||
system,
|
||||
{},
|
||||
prompts
|
||||
)
|
||||
|
||||
self.terms = self.config.global_terms
|
||||
self.prompts = self.config.prompts
|
||||
|
||||
try:
|
||||
self.system_template = ibis.Template(self.config.system_template)
|
||||
except:
|
||||
raise RuntimeError("Error in system template")
|
||||
|
||||
|
|
@ -34,8 +71,8 @@ class PromptManager:
|
|||
for k, v in self.prompts.items():
|
||||
try:
|
||||
self.templates[k] = ibis.Template(v.template)
|
||||
except:
|
||||
raise RuntimeError(f"Error in template: {k}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in template: {k}: {e}")
|
||||
|
||||
if v.terms is None:
|
||||
v.terms = {}
|
||||
|
|
@ -51,9 +88,7 @@ class PromptManager:
|
|||
|
||||
return json.loads(json_str)
|
||||
|
||||
async def invoke(self, id, input, llm):
|
||||
|
||||
print("Invoke...", flush=True)
|
||||
def render(self, id, input):
|
||||
|
||||
if id not in self.prompts:
|
||||
raise RuntimeError("ID invalid")
|
||||
|
|
@ -62,9 +97,19 @@ class PromptManager:
|
|||
|
||||
resp_type = self.prompts[id].response_type
|
||||
|
||||
return self.templates[id].render(terms)
|
||||
|
||||
async def invoke(self, id, input, llm):
|
||||
|
||||
print("Invoke...", flush=True)
|
||||
|
||||
terms = self.terms | self.prompts[id].terms | input
|
||||
|
||||
resp_type = self.prompts[id].response_type
|
||||
|
||||
prompt = {
|
||||
"system": self.system_template.render(terms),
|
||||
"prompt": self.templates[id].render(terms)
|
||||
"prompt": self.render(id, input)
|
||||
}
|
||||
|
||||
resp = await llm(**prompt)
|
||||
Loading…
Add table
Add a link
Reference in a new issue