Update to enable knowledge extraction using the agent framework (#439)

* Implement KG extraction agent (kg-extract-agent)

* Using ReAct framework (agent-manager-react)
 
* The ReAct manager had an issue when emitting JSON, which conflicted with the ReAct manager's own JSON messages, so the ReAct manager was refactored to use traditional ReAct messages with a non-JSON structure.
 
* Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework.
This commit is contained in:
cybermaggedon 2025-07-21 14:31:57 +01:00 committed by GitHub
parent 1fe4ed5226
commit d83e4e3d59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 3192 additions and 799 deletions

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Console entry point for the kg-extract-agent service: delegates to the
# package's run() helper, which launches the flow processor.
from trustgraph.extract.kg.agent import run
run()

View file

@ -1,6 +0,0 @@
#!/usr/bin/env python3
# Console entry point for the (removed) generic prompt service.
from trustgraph.model.prompt.generic import run
run()

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3
from trustgraph.model.prompt.template import run
from trustgraph.prompt.template import run
run()

View file

@ -96,7 +96,7 @@ setuptools.setup(
"scripts/graph-rag",
"scripts/kg-extract-definitions",
"scripts/kg-extract-relationships",
"scripts/kg-extract-topics",
"scripts/kg-extract-agent",
"scripts/kg-store",
"scripts/kg-manager",
"scripts/librarian",
@ -106,7 +106,6 @@ setuptools.setup(
"scripts/oe-write-milvus",
"scripts/pdf-decoder",
"scripts/pdf-ocr-mistral",
"scripts/prompt-generic",
"scripts/prompt-template",
"scripts/rows-write-cassandra",
"scripts/run-processing",

View file

@ -1,6 +1,7 @@
import logging
import json
import re
from . types import Action, Final
@ -12,6 +13,155 @@ class AgentManager:
self.tools = tools
self.additional_context = additional_context
def parse_react_response(self, text):
    """Parse a text-based ReAct response into an Action or Final object.

    Expected format:

        Thought: [reasoning about what to do next]
        Action: [tool_name]
        Args: {
            "param": "value"
        }

    OR

        Thought: [reasoning about the final answer]
        Final Answer: [the answer]

    Raises ValueError when the text cannot be parsed into either form.
    """
    if not isinstance(text, str):
        raise ValueError(f"Expected string response, got {type(text)}")

    # Remove any markdown code blocks that might wrap the response
    text = re.sub(r'^```[^\n]*\n', '', text.strip())
    text = re.sub(r'\n```$', '', text.strip())

    lines = text.strip().split('\n')

    thought = None
    action = None
    args = None
    final_answer = None

    i = 0
    while i < len(lines):
        line = lines[i].strip()

        # Parse Thought; a thought may continue across lines until the
        # next section keyword is seen.
        if line.startswith("Thought:"):
            thought = line[8:].strip()
            i += 1
            while i < len(lines):
                next_line = lines[i].strip()
                if next_line.startswith(("Action:", "Final Answer:", "Args:")):
                    break
                thought += " " + next_line
                i += 1
            continue

        # Parse Final Answer; multi-line answers (including JSON bodies)
        # are accumulated before returning.
        if line.startswith("Final Answer:"):
            final_answer = line[13:].strip()
            i += 1
            # Check if the answer might be JSON
            if final_answer.startswith('{') or (i < len(lines) and lines[i].strip().startswith('{')):
                # Collect the JSON text by tracking brace balance
                json_text = final_answer if final_answer.startswith('{') else ""
                brace_count = json_text.count('{') - json_text.count('}')
                while i < len(lines) and (brace_count > 0 or not json_text):
                    current_line = lines[i].strip()
                    if current_line.startswith(("Thought:", "Action:")) and brace_count == 0:
                        break
                    json_text += ("\n" if json_text else "") + current_line
                    brace_count += current_line.count('{') - current_line.count('}')
                    i += 1
                # The JSON is deliberately NOT decoded here; the raw text
                # is returned and any decoding is left to the caller.
                final_answer = json_text
            else:
                # Regular text answer: consume lines until a new section
                while i < len(lines):
                    next_line = lines[i].strip()
                    if next_line.startswith(("Thought:", "Action:")):
                        break
                    final_answer += " " + next_line
                    i += 1
            # A final answer terminates parsing immediately
            return Final(
                thought=thought or "",
                final=final_answer
            )

        # Parse Action (tool name)
        if line.startswith("Action:"):
            action = line[7:].strip()

        # Parse Args: a JSON object which may start on the same line or
        # on a following line; brace counting finds its end.
        if line.startswith("Args:"):
            args_on_same_line = line[5:].strip()
            if args_on_same_line:
                args_text = args_on_same_line
                brace_count = args_on_same_line.count('{') - args_on_same_line.count('}')
            else:
                args_text = ""
                brace_count = 0
            i += 1
            started = bool(args_on_same_line and '{' in args_on_same_line)
            while i < len(lines) and (not started or brace_count > 0):
                current_line = lines[i]
                args_text += ("\n" if args_text else "") + current_line
                # Count braces to determine when JSON is complete
                for char in current_line:
                    if char == '{':
                        brace_count += 1
                        started = True
                    elif char == '}':
                        brace_count -= 1
                # If we've started and braces are balanced, we're done
                if started and brace_count == 0:
                    break
                i += 1
            try:
                args = json.loads(args_text.strip())
            except json.JSONDecodeError as e:
                # NOTE(review): assumes a module-level `logger` exists —
                # not visible in this hunk; confirm.
                logger.error(f"Failed to parse JSON arguments: {args_text}")
                raise ValueError(f"Invalid JSON in Args: {e}")

        i += 1

    # An action (with optional args) produces an Action object
    if action:
        return Action(
            thought=thought or "",
            name=action,
            arguments=args or {},
            observation=""
        )

    # A thought with no actionable content is an error
    if thought and not action and not final_answer:
        raise ValueError(f"Response has thought but no action or final answer: {text}")

    raise ValueError(f"Could not parse response: {text}")
async def reason(self, question, history, context):
print(f"calling reason: {question}", flush=True)
@ -62,31 +212,23 @@ class AgentManager:
logger.info(f"prompt: {variables}")
obj = await context("prompt-request").agent_react(variables)
# Get text response from prompt service
response_text = await context("prompt-request").agent_react(variables)
print(json.dumps(obj, indent=4), flush=True)
print(f"Response text:\n{response_text}", flush=True)
logger.info(f"response: {obj}")
logger.info(f"response: {response_text}")
if obj.get("final-answer"):
a = Final(
thought = obj.get("thought"),
final = obj.get("final-answer"),
)
return a
else:
a = Action(
thought = obj.get("thought"),
name = obj.get("action"),
arguments = obj.get("arguments"),
observation = ""
)
return a
# Parse the text response
try:
result = self.parse_react_response(response_text)
logger.info(f"Parsed result: {result}")
return result
except ValueError as e:
logger.error(f"Failed to parse response: {e}")
# Try to provide a helpful error message
logger.error(f"Response was: {response_text}")
raise RuntimeError(f"Failed to parse agent response: {e}")
async def react(self, question, history, think, observe, context):
@ -120,7 +262,11 @@ class AgentManager:
**act.arguments
)
resp = resp.strip()
if isinstance(resp, str):
resp = resp.strip()
else:
resp = str(resp)
resp = resp.strip()
logger.info(f"resp: {resp}")

View file

@ -6,6 +6,10 @@ import json
import re
import sys
import functools
import logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec
@ -221,6 +225,11 @@ class Processor(AgentService):
print("Send final response...", flush=True)
if isinstance(act.final, str):
f = act.final
else:
f = json.dumps(act.final)
r = AgentResponse(
answer=act.final,
error=None,
@ -292,6 +301,5 @@ class Processor(AgentService):
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1 @@
from .extract import *

View file

@ -0,0 +1,4 @@
from .extract import Processor
if __name__ == "__main__":
Processor.run()

View file

@ -0,0 +1,336 @@
import re
import json
import urllib.parse
from ....schema import Chunk, Triple, Triples, Metadata, Value
from ....schema import EntityContext, EntityContexts
from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, SUBJECT_OF, DEFINITION
from ....base import FlowProcessor, ConsumerSpec, ProducerSpec
from ....base import AgentClientSpec
from ....template import PromptManager
default_ident = "kg-extract-agent"
default_concurrency = 1
default_template_id = "agent-kg-extract"
default_config_type = "prompt"
class Processor(FlowProcessor):
def __init__(self, **params):
    """Configure the KG-extraction agent processor.

    Wires up a chunk consumer, an agent-service client, and producers
    for triples and entity contexts.  Parameter keys use the external
    '-' spellings used across the framework configuration.
    """
    id = params.get("id")
    concurrency = params.get("concurrency", 1)
    template_id = params.get("template-id", default_template_id)
    config_key = params.get("config-type", default_config_type)

    super().__init__(**params | {
        "id": id,
        "template-id": template_id,
        "config-type": config_key,
        "concurrency": concurrency,
    })

    self.concurrency = concurrency
    # Prompt template used to render the extraction request
    self.template_id = template_id
    # Configuration namespace holding the prompt templates
    self.config_key = config_key

    # Re-load templates whenever the prompt configuration changes
    self.register_config_handler(self.on_prompt_config)

    # Incoming text chunks, handled concurrently
    self.register_specification(
        ConsumerSpec(
            name = "input",
            schema = Chunk,
            handler = self.on_message,
            concurrency = self.concurrency,
        )
    )

    # Client used to invoke the agent service with the rendered prompt
    self.register_specification(
        AgentClientSpec(
            request_name = "agent-request",
            response_name = "agent-response",
        )
    )

    # Extracted knowledge-graph triples
    self.register_specification(
        ProducerSpec(
            name="triples",
            schema=Triples,
        )
    )

    # Entity/definition contexts, used downstream for embeddings
    self.register_specification(
        ProducerSpec(
            name="entity-contexts",
            schema=EntityContexts,
        )
    )

    # Null configuration, should reload quickly
    self.manager = PromptManager()
async def on_prompt_config(self, config, version):
    """Reload prompt templates when a new configuration version arrives.

    On any failure the previous configuration stays in effect.
    """
    print("Loading configuration version", version, flush=True)

    if self.config_key not in config:
        print(f"No key {self.config_key} in config", flush=True)
        return

    # Narrow to this processor's configuration namespace
    config = config[self.config_key]

    try:
        self.manager.load_config(config)
        print("Prompt configuration reloaded.", flush=True)
    except Exception as e:
        # Keep serving with the old templates rather than crashing
        print("Exception:", e, flush=True)
        print("Configuration reload failed", flush=True)
def to_uri(self, text):
    """Map an entity label to a URI in the TrustGraph entity namespace."""
    quoted = urllib.parse.quote(text)
    return TRUSTGRAPH_ENTITIES + quoted
async def emit_triples(self, pub, metadata, triples):
    """Publish a batch of triples, carrying over the source metadata."""
    meta = Metadata(
        id = metadata.id,
        metadata = [],
        user = metadata.user,
        collection = metadata.collection,
    )
    message = Triples(metadata = meta, triples = triples)
    await pub.send(message)
async def emit_entity_contexts(self, pub, metadata, entity_contexts):
    """Publish a batch of entity contexts, carrying over the source metadata."""
    meta = Metadata(
        id = metadata.id,
        metadata = [],
        user = metadata.user,
        collection = metadata.collection,
    )
    message = EntityContexts(metadata = meta, entities = entity_contexts)
    await pub.send(message)
def parse_json(self, text):
    """Decode a JSON payload from an LLM reply, tolerating ``` fences.

    If the text contains a fenced block (optionally tagged "json"), only
    that block is decoded; otherwise the whole text is treated as JSON.
    """
    fenced = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
    payload = fenced.group(1).strip() if fenced else text.strip()
    return json.loads(payload)
async def on_message(self, msg, consumer, flow):
    """Handle one incoming chunk: render the extraction prompt, run it
    through the agent service, parse the JSON reply, and emit triples
    and entity contexts.  Exceptions propagate after being logged."""
    try:
        v = msg.value()

        # Extract chunk text
        chunk_text = v.chunk.decode('utf-8')

        print("Got chunk", flush=True)

        # Render the configured extraction template with the chunk text
        prompt = self.manager.render(
            self.template_id,
            {
                "text": chunk_text
            }
        )

        print("Prompt:", prompt, flush=True)

        async def handle(response):
            # Per-response callback: raise on error responses, accept
            # the response once an answer is present.
            print("Response:", response, flush=True)
            if response.error is not None:
                if response.error.message:
                    raise RuntimeError(str(response.error.message))
                else:
                    raise RuntimeError(str(response.error))
            if response.answer is not None:
                return True
            else:
                return False

        # Send to agent API
        # NOTE(review): assumes invoke() returns the accepted answer
        # text — confirm against AgentClientSpec.
        agent_response = await flow("agent-request").invoke(
            recipient = handle,
            question = prompt
        )

        # Parse JSON response
        try:
            extraction_data = self.parse_json(agent_response)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON response from agent: {e}")

        # Convert the extraction JSON into triples + entity contexts
        triples, entity_contexts = self.process_extraction_data(
            extraction_data, v.metadata
        )

        # Put document metadata into triples
        for t in v.metadata.metadata:
            triples.append(t)

        # Emit outputs (skip empty batches)
        if triples:
            await self.emit_triples(flow("triples"), v.metadata, triples)

        if entity_contexts:
            await self.emit_entity_contexts(
                flow("entity-contexts"),
                v.metadata,
                entity_contexts
            )

    except Exception as e:
        # Log and re-raise so the consumer framework sees the failure
        print(f"Error processing chunk: {e}", flush=True)
        raise
def process_extraction_data(self, data, metadata):
    """Convert the agent's extraction JSON into triples and entity contexts.

    data: dict with optional "definitions" (entity/definition records) and
    "relationships" (subject/predicate/object records with an optional
    "object-entity" boolean) keys.
    metadata: source document metadata; when metadata.id is set, each
    extracted node is linked back to the document via SUBJECT_OF.

    Returns (triples, entity_contexts).
    """
    triples = []
    entity_contexts = []

    # Process definitions
    for defn in data.get("definitions", []):
        entity_uri = self.to_uri(defn["entity"])

        # Entity label (a literal)
        triples.append(Triple(
            s = Value(value=entity_uri, is_uri=True),
            p = Value(value=RDF_LABEL, is_uri=True),
            o = Value(value=defn["entity"], is_uri=False),
        ))

        # Definition text
        triples.append(Triple(
            s = Value(value=entity_uri, is_uri=True),
            p = Value(value=DEFINITION, is_uri=True),
            o = Value(value=defn["definition"], is_uri=False),
        ))

        # Link the entity back to its source document
        if metadata.id:
            triples.append(Triple(
                s = Value(value=entity_uri, is_uri=True),
                p = Value(value=SUBJECT_OF, is_uri=True),
                o = Value(value=metadata.id, is_uri=True),
            ))

        # Entity context, used downstream for embeddings
        entity_contexts.append(EntityContext(
            entity=Value(value=entity_uri, is_uri=True),
            context=defn["definition"]
        ))

    # Process relationships
    for rel in data.get("relationships", []):
        subject_uri = self.to_uri(rel["subject"])
        predicate_uri = self.to_uri(rel["predicate"])

        subject_value = Value(value=subject_uri, is_uri=True)
        predicate_value = Value(value=predicate_uri, is_uri=True)

        # FIX: the entity/literal flag lives on the relationship record
        # (previously read from the top-level data with the wrong
        # default), and the object value must be derived from the
        # relationship's own object (previously the predicate URI was
        # used for both branches).
        is_entity = rel.get("object-entity", True)
        if is_entity:
            object_value = Value(value=self.to_uri(rel["object"]), is_uri=True)
        else:
            object_value = Value(value=rel["object"], is_uri=False)

        # Labels for subject and predicate (literals)
        triples.append(Triple(
            s = subject_value,
            p = Value(value=RDF_LABEL, is_uri=True),
            o = Value(value=rel["subject"], is_uri=False),
        ))

        triples.append(Triple(
            s = predicate_value,
            p = Value(value=RDF_LABEL, is_uri=True),
            o = Value(value=rel["predicate"], is_uri=False),
        ))

        # Label for an entity object; FIX: labels are literals, so the
        # object of the label triple is not a URI.
        if is_entity:
            triples.append(Triple(
                s = object_value,
                p = Value(value=RDF_LABEL, is_uri=True),
                o = Value(value=rel["object"], is_uri=False),
            ))

        # The relationship itself
        triples.append(Triple(
            s = subject_value,
            p = predicate_value,
            o = object_value
        ))

        # Link the nodes back to the source document
        if metadata.id:
            triples.append(Triple(
                s = subject_value,
                p = Value(value=SUBJECT_OF, is_uri=True),
                o = Value(value=metadata.id, is_uri=True),
            ))
            triples.append(Triple(
                s = predicate_value,
                p = Value(value=SUBJECT_OF, is_uri=True),
                o = Value(value=metadata.id, is_uri=True),
            ))
            if is_entity:
                triples.append(Triple(
                    s = object_value,
                    p = Value(value=SUBJECT_OF, is_uri=True),
                    o = Value(value=metadata.id, is_uri=True),
                ))

    return triples, entity_contexts
@staticmethod
def add_args(parser):
    # Processor-specific command-line options; FlowProcessor contributes
    # the shared flow/consumer options.
    parser.add_argument(
        '-c', '--concurrency',
        type=int,
        default=default_concurrency,
        help=f'Concurrent processing threads (default: {default_concurrency})'
    )

    parser.add_argument(
        "--template-id",
        type=str,
        default=default_template_id,
        help="Template ID to use for agent extraction"
    )

    parser.add_argument(
        '--config-type',
        default="prompt",
        help=f'Configuration key for prompts (default: prompt)',
    )

    FlowProcessor.add_args(parser)
def run():
    # Console-script entry point: launch under the default identity.
    Processor.launch(default_ident, __doc__)

View file

@ -1,176 +0,0 @@
def to_relationships(text):
    """Build the relationship-extraction prompt for the given text."""
    return f"""You are a helpful assistant that performs information extraction tasks for a provided text.
Read the provided text. You will model the text as an information network for a RDF knowledge graph in JSON.
Information Network Rules:
- An information network has subjects connected by predicates to objects.
- A subject is a named-entity or a conceptual topic.
- One subject can have many predicates and objects.
- An object is a property or attribute of a subject.
- A subject can be connected by a predicate to another subject.
Reading Instructions:
- Ignore document formatting in the provided text.
- Study the provided text carefully.
Here is the text:
{text}
Response Instructions:
- Obey the information network rules.
- Do not return special characters.
- Respond only with well-formed JSON.
- The JSON response shall be an array of JSON objects with keys "subject", "predicate", "object", and "object-entity".
- The JSON response shall use the following structure:
```json
[{{"subject": string, "predicate": string, "object": string, "object-entity": boolean}}]
```
- The key "object-entity" is TRUE only if the "object" is a subject.
- Do not write any additional text or explanations.
"""
def to_topics(text):
    """Build the topic-extraction prompt for the given text."""
    return f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify topics and their definitions in JSON.
Reading Instructions:
- Ignore document formatting in the provided text.
- Study the provided text carefully.
Here is the text:
{text}
Response Instructions:
- Do not respond with special characters.
- Return only topics that are concepts and unique to the provided text.
- Respond only with well-formed JSON.
- The JSON response shall be an array of objects with keys "topic" and "definition".
- The JSON response shall use the following structure:
```json
[{{"topic": string, "definition": string}}]
```
- Do not write any additional text or explanations.
"""
def to_definitions(text):
    """Build the entity-definition extraction prompt for the given text.

    Fixes a typo in the instruction text: "commodotities" -> "commodities".
    """
    prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify entities and their definitions in JSON.
Reading Instructions:
- Ignore document formatting in the provided text.
- Study the provided text carefully.
Here is the text:
{text}
Response Instructions:
- Do not respond with special characters.
- Return only entities that are named-entities such as: people, organizations, physical objects, locations, animals, products, commodities, or substances.
- Respond only with well-formed JSON.
- The JSON response shall be an array of objects with keys "entity" and "definition".
- The JSON response shall use the following structure:
```json
[{{"entity": string, "definition": string}}]
```
- Do not write any additional text or explanations.
"""
    return prompt
def to_rows(schema, text):
    """Build the row-extraction prompt for the given object schema and text."""
    field_lines = "\n".join(
        f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
        for f in schema.fields
    )
    schema_block = f"""Object name: {schema.name}
Description: {schema.description}
Fields:
{field_lines}"""
    return f"""<instructions>
Study the following text and derive objects which match the schema provided.
You must output an array of JSON objects for each object you discover
which matches the schema. For each object, output a JSON object whose fields
carry the name field specified in the schema.
</instructions>
<schema>
{schema_block}
</schema>
<text>
{text}
</text>
<requirements>
You will respond only with raw JSON format data. Do not provide
explanations. Do not add markdown formatting or headers or prefixes.
</requirements>"""
def get_cypher(kg):
    """Render triples as pseudo-Cypher statements, one per line.

    kg: iterable of objects with .s, .p, .o attributes.
    Backslashes are replaced with '-' since they corrupt the prompt text.
    Fix: stray debug print() calls removed from library code.
    """
    statements = [f"({t.s})-[{t.p}]->({t.o})" for t in kg]
    text = "\n".join(statements)
    return text.replace("\\", "-")

def to_kg_query(query, kg):
    """Build a graph-grounded question prompt from knowledge-graph triples."""
    cypher = get_cypher(kg)
    return f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
Here's the knowledge statements:
{cypher}
Use only the provided knowledge statements to respond to the following:
{query}
"""
def to_document_query(query, documents):
    """Build a document-grounded question prompt from retrieved context."""
    context = "\n\n".join(documents)
    return f"""Study the following context. Use only the information provided in the context in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
Here is the context:
{context}
Use only the provided knowledge statements to respond to the following:
{query}
"""

View file

@ -1,485 +0,0 @@
"""
Language service abstracts prompt engineering from LLM.
"""
#
# FIXME: This module is broken, it doesn't conform to the prompt API change
# made in 0.14, nor the prompt template support.
#
# It could be made to conform by using prompt-template as a starting
# point, and hard-coding all the information.
#
import json
import re
from .... schema import Definition, Relationship, Triple
from .... schema import Topic
from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... schema import prompt_request_queue, prompt_response_queue
from .... base import ConsumerProducer
from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships, to_topics
from . prompts import to_kg_query, to_document_query, to_rows
module = "prompt"
default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
    """Configure the (legacy) prompt service processor.

    Wires the prompt request/response queues plus the underlying
    text-completion queues used by the LLM client.
    """
    input_queue = params.get("input_queue", default_input_queue)
    output_queue = params.get("output_queue", default_output_queue)
    subscriber = params.get("subscriber", default_subscriber)

    # Queues for the downstream text-completion (LLM) service
    tc_request_queue = params.get(
        "text_completion_request_queue", text_completion_request_queue
    )
    tc_response_queue = params.get(
        "text_completion_response_queue", text_completion_response_queue
    )

    super(Processor, self).__init__(
        **params | {
            "input_queue": input_queue,
            "output_queue": output_queue,
            "subscriber": subscriber,
            "input_schema": PromptRequest,
            "output_schema": PromptResponse,
            "text_completion_request_queue": tc_request_queue,
            "text_completion_response_queue": tc_response_queue,
        }
    )

    # LLM client rides on the text-completion queues
    self.llm = LlmClient(
        subscriber=subscriber,
        input_queue=tc_request_queue,
        output_queue=tc_response_queue,
        pulsar_host = self.pulsar_host,
        pulsar_api_key=self.pulsar_api_key,
    )
def parse_json(self, text):
    """Decode a JSON payload from an LLM reply, tolerating ``` fences.

    If the text contains a fenced block (optionally tagged "json"), only
    that block is decoded; otherwise the whole text is treated as JSON.
    """
    fenced = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
    payload = fenced.group(1).strip() if fenced else text.strip()
    return json.loads(payload)
async def handle(self, msg):
    """Dispatch an incoming prompt request to the handler for its kind.

    Unknown kinds are logged and ignored.
    """
    v = msg.value()

    # Sender-produced ID
    id = msg.properties()["id"]

    kind = v.kind

    print(f"Handling kind {kind}...", flush=True)

    # Table-driven dispatch instead of an if/elif chain
    handler_name = {
        "extract-definitions": "handle_extract_definitions",
        "extract-topics": "handle_extract_topics",
        "extract-relationships": "handle_extract_relationships",
        "extract-rows": "handle_extract_rows",
        "kg-prompt": "handle_kg_prompt",
        "document-prompt": "handle_document_prompt",
    }.get(kind)

    if handler_name is None:
        print("Invalid kind.", flush=True)
        return

    await getattr(self, handler_name)(id, v)
    return
async def handle_extract_definitions(self, id, v):
    """Extract entity definitions from a chunk and reply with Definition
    records; failures are reported back as an llm-error response."""
    try:
        prompt = to_definitions(v.chunk)

        # NOTE(review): called synchronously — confirm LlmClient.request
        # is not a coroutine.
        ans = self.llm.request(prompt)

        # Silently ignore JSON parse error
        try:
            defs = self.parse_json(ans)
        except:
            print("JSON parse error, ignored", flush=True)
            defs = []

        output = []

        for defn in defs:
            try:
                e = defn["entity"]
                d = defn["definition"]
                # Skip records with empty or missing fields
                if e == "": continue
                if e is None: continue
                if d == "": continue
                if d is None: continue
                output.append(
                    Definition(
                        name=e, definition=d
                    )
                )
            except:
                print("definition fields missing, ignored", flush=True)

        print("Send response...", flush=True)

        r = PromptResponse(definitions=output, error=None)
        await self.send(r, properties={"id": id})

        print("Done.", flush=True)

    except Exception as e:
        # Any failure becomes an error response to the caller
        print(f"Exception: {e}")
        print("Send error response...", flush=True)
        r = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
        )
        await self.send(r, properties={"id": id})
async def handle_extract_topics(self, id, v):
    """Extract conceptual topics + definitions from a chunk; mirrors
    handle_extract_definitions but emits Topic records."""
    try:
        prompt = to_topics(v.chunk)
        ans = self.llm.request(prompt)

        # Silently ignore JSON parse error
        try:
            defs = self.parse_json(ans)
        except:
            print("JSON parse error, ignored", flush=True)
            defs = []

        output = []

        for defn in defs:
            try:
                e = defn["topic"]
                d = defn["definition"]
                # Skip records with empty or missing fields
                if e == "": continue
                if e is None: continue
                if d == "": continue
                if d is None: continue
                output.append(
                    Topic(
                        name=e, definition=d
                    )
                )
            except:
                print("definition fields missing, ignored", flush=True)

        print("Send response...", flush=True)

        r = PromptResponse(topics=output, error=None)
        await self.send(r, properties={"id": id})

        print("Done.", flush=True)

    except Exception as e:
        print(f"Exception: {e}")
        print("Send error response...", flush=True)
        r = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
        )
        await self.send(r, properties={"id": id})
async def handle_extract_relationships(self, id, v):
    """Extract subject/predicate/object relationships from a chunk and
    reply with Relationship records."""
    try:
        prompt = to_relationships(v.chunk)
        ans = self.llm.request(prompt)

        # Silently ignore JSON parse error
        try:
            defs = self.parse_json(ans)
        except:
            print("JSON parse error, ignored", flush=True)
            defs = []

        output = []

        for defn in defs:
            try:
                s = defn["subject"]
                p = defn["predicate"]
                o = defn["object"]
                o_entity = defn["object-entity"]
                # Skip records with empty or missing core fields
                if s == "": continue
                if s is None: continue
                if p == "": continue
                if p is None: continue
                if o == "": continue
                if o is None: continue
                # A missing/empty flag means "object is a literal"
                if o_entity == "" or o_entity is None:
                    o_entity = False
                output.append(
                    Relationship(
                        s = s,
                        p = p,
                        o = o,
                        o_entity = o_entity,
                    )
                )
            except Exception as e:
                print("relationship fields missing, ignored", flush=True)

        print("Send response...", flush=True)

        r = PromptResponse(relationships=output, error=None)
        await self.send(r, properties={"id": id})

        print("Done.", flush=True)

    except Exception as e:
        print(f"Exception: {e}")
        print("Send error response...", flush=True)
        r = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
        )
        await self.send(r, properties={"id": id})
async def handle_extract_rows(self, id, v):
    """Extract structured rows matching the request's row schema and
    reply with a rows response."""
    try:
        fields = v.row_schema.fields

        prompt = to_rows(v.row_schema, v.chunk)
        print(prompt)

        ans = self.llm.request(prompt)
        print(ans)

        # Silently ignore JSON parse error
        try:
            objs = self.parse_json(ans)
        except:
            print("JSON parse error, ignored", flush=True)
            objs = []

        output = []

        for obj in objs:
            try:
                row = {}
                for f in fields:
                    # Require every schema field; drop incomplete objects
                    if f.name not in obj:
                        print(f"Object ignored, missing field {f.name}")
                        row = {}
                        break
                    row[f.name] = obj[f.name]
                if row == {}:
                    continue
                output.append(row)
            except Exception as e:
                print("row fields missing, ignored", flush=True)

        for row in output:
            print(row)

        print("Send response...", flush=True)

        r = PromptResponse(rows=output, error=None)
        await self.send(r, properties={"id": id})

        print("Done.", flush=True)

    except Exception as e:
        print(f"Exception: {e}")
        print("Send error response...", flush=True)
        r = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
        )
        await self.send(r, properties={"id": id})
async def handle_kg_prompt(self, id, v):
    """Answer a question grounded in knowledge-graph triples."""
    try:
        prompt = to_kg_query(v.query, v.kg)
        print(prompt)

        ans = self.llm.request(prompt)
        print(ans)

        print("Send response...", flush=True)

        r = PromptResponse(answer=ans, error=None)
        await self.send(r, properties={"id": id})

        print("Done.", flush=True)

    except Exception as e:
        print(f"Exception: {e}")
        print("Send error response...", flush=True)
        r = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
        )
        await self.send(r, properties={"id": id})
async def handle_document_prompt(self, id, v):
    """Answer a question grounded in retrieved document fragments."""
    try:
        prompt = to_document_query(v.query, v.documents)
        print("prompt")
        print(prompt)

        print("Call LLM...")
        ans = self.llm.request(prompt)
        print(ans)

        print("Send response...", flush=True)

        r = PromptResponse(answer=ans, error=None)
        await self.send(r, properties={"id": id})

        print("Done.", flush=True)

    except Exception as e:
        print(f"Exception: {e}")
        print("Send error response...", flush=True)
        r = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
        )
        await self.send(r, properties={"id": id})
@staticmethod
def add_args(parser):
    # Shared consumer/producer queue options, plus the text-completion
    # queues this service forwards LLM work through.
    ConsumerProducer.add_args(
        parser, default_input_queue, default_subscriber,
        default_output_queue,
    )

    parser.add_argument(
        '--text-completion-request-queue',
        default=text_completion_request_queue,
        help=f'Text completion request queue (default: {text_completion_request_queue})',
    )

    parser.add_argument(
        '--text-completion-response-queue',
        default=text_completion_response_queue,
        help=f'Text completion response queue (default: {text_completion_response_queue})',
    )
def run():
    # Deliberately disabled (see module FIXME): this service predates the
    # 0.14 prompt API change and must not be launched.
    raise RuntimeError("NOT IMPLEMENTED")
    # NOTE(review): unreachable after the raise above.
    Processor.launch(module, __doc__)

View file

@ -1,3 +0,0 @@
from . service import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
# Launcher stub for the removed generic prompt service.
from . service import run
if __name__ == '__main__':
    run()

View file

@ -7,15 +7,15 @@ import asyncio
import json
import re
from .... schema import Definition, Relationship, Triple
from .... schema import Topic
from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse
from ...schema import Definition, Relationship, Triple
from ...schema import Topic
from ...schema import PromptRequest, PromptResponse, Error
from ...schema import TextCompletionRequest, TextCompletionResponse
from .... base import FlowProcessor
from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
from ...base import FlowProcessor
from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
from ...template import PromptManager
default_ident = "prompt"
default_concurrency = 1
@ -33,6 +33,7 @@ class Processor(FlowProcessor):
super(Processor, self).__init__(
**params | {
"id": id,
"config-type": self.config_key,
"concurrency": concurrency,
}
)
@ -63,9 +64,7 @@ class Processor(FlowProcessor):
self.register_config_handler(self.on_prompt_config)
# Null configuration, should reload quickly
self.manager = PromptManager(
config = PromptConfiguration("", {}, {})
)
self.manager = PromptManager()
async def on_prompt_config(self, config, version):
@ -79,34 +78,7 @@ class Processor(FlowProcessor):
try:
system = json.loads(config["system"])
ix = json.loads(config["template-index"])
prompts = {}
for k in ix:
pc = config[f"template.{k}"]
data = json.loads(pc)
prompt = data.get("prompt")
rtype = data.get("response-type", "text")
schema = data.get("schema", None)
prompts[k] = Prompt(
template = prompt,
response_type = rtype,
schema = schema,
terms = {}
)
self.manager = PromptManager(
PromptConfiguration(
system,
{},
prompts
)
)
self.manager.load_config(config)
print("Prompt configuration reloaded.", flush=True)
@ -230,14 +202,14 @@ class Processor(FlowProcessor):
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)
parser.add_argument(
'--config-type',
default="prompt",
help=f'Configuration key for prompts (default: prompt)',
)
FlowProcessor.add_args(parser)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,3 @@
from .prompt_manager import *

View file

@ -19,14 +19,51 @@ class Prompt:
class PromptManager:
def __init__(self, config):
self.config = config
self.terms = config.global_terms
def __init__(self):
self.prompts = config.prompts
self.load_config({})
def load_config(self, config):
try:
self.system_template = ibis.Template(config.system_template)
system = json.loads(config["system"])
except:
system = "Be helpful."
try:
ix = json.loads(config["template-index"])
except:
ix = []
prompts = {}
for k in ix:
pc = config[f"template.{k}"]
data = json.loads(pc)
prompt = data.get("prompt")
rtype = data.get("response-type", "text")
schema = data.get("schema", None)
prompts[k] = Prompt(
template = prompt,
response_type = rtype,
schema = schema,
terms = {}
)
self.config = PromptConfiguration(
system,
{},
prompts
)
self.terms = self.config.global_terms
self.prompts = self.config.prompts
try:
self.system_template = ibis.Template(self.config.system_template)
except:
raise RuntimeError("Error in system template")
@ -34,8 +71,8 @@ class PromptManager:
for k, v in self.prompts.items():
try:
self.templates[k] = ibis.Template(v.template)
except:
raise RuntimeError(f"Error in template: {k}")
except Exception as e:
raise RuntimeError(f"Error in template: {k}: {e}")
if v.terms is None:
v.terms = {}
@ -51,9 +88,7 @@ class PromptManager:
return json.loads(json_str)
async def invoke(self, id, input, llm):
print("Invoke...", flush=True)
def render(self, id, input):
if id not in self.prompts:
raise RuntimeError("ID invalid")
@ -62,9 +97,19 @@ class PromptManager:
resp_type = self.prompts[id].response_type
return self.templates[id].render(terms)
async def invoke(self, id, input, llm):
print("Invoke...", flush=True)
terms = self.terms | self.prompts[id].terms | input
resp_type = self.prompts[id].response_type
prompt = {
"system": self.system_template.render(terms),
"prompt": self.templates[id].render(terms)
"prompt": self.render(id, input)
}
resp = await llm(**prompt)