mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-28 00:35:13 +02:00
Feature/configure flows (#345)
- Keeps processing in different flows separate so that data can go to different stores / collections etc. - Potentially supports different processing flows - Tidies the processing API with common base-classes for e.g. LLMs, and automatic configuration of 'clients' to use the right queue names in a flow
This commit is contained in:
parent
a06a814a41
commit
a9197d11ee
125 changed files with 3751 additions and 2628 deletions
|
|
@ -8,12 +8,11 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class AgentManager:
|
||||
|
||||
def __init__(self, context, tools, additional_context=None):
|
||||
self.context = context
|
||||
def __init__(self, tools, additional_context=None):
|
||||
self.tools = tools
|
||||
self.additional_context = additional_context
|
||||
|
||||
def reason(self, question, history):
|
||||
async def reason(self, question, history, context):
|
||||
|
||||
tools = self.tools
|
||||
|
||||
|
|
@ -56,10 +55,7 @@ class AgentManager:
|
|||
|
||||
logger.info(f"prompt: {variables}")
|
||||
|
||||
obj = self.context.prompt.request(
|
||||
"agent-react",
|
||||
variables
|
||||
)
|
||||
obj = await context("prompt-request").agent_react(variables)
|
||||
|
||||
print(json.dumps(obj, indent=4), flush=True)
|
||||
|
||||
|
|
@ -85,9 +81,13 @@ class AgentManager:
|
|||
|
||||
return a
|
||||
|
||||
async def react(self, question, history, think, observe):
|
||||
async def react(self, question, history, think, observe, context):
|
||||
|
||||
act = self.reason(question, history)
|
||||
act = await self.reason(
|
||||
question = question,
|
||||
history = history,
|
||||
context = context,
|
||||
)
|
||||
logger.info(f"act: {act}")
|
||||
|
||||
if isinstance(act, Final):
|
||||
|
|
@ -104,7 +104,12 @@ class AgentManager:
|
|||
else:
|
||||
raise RuntimeError(f"No action for {act.name}!")
|
||||
|
||||
resp = action.implementation.invoke(**act.arguments)
|
||||
print("TOOL>>>", act)
|
||||
resp = await action.implementation(context).invoke(
|
||||
**act.arguments
|
||||
)
|
||||
|
||||
print("RSETUL", resp)
|
||||
|
||||
resp = resp.strip()
|
||||
|
||||
|
|
|
|||
|
|
@ -6,103 +6,68 @@ import json
|
|||
import re
|
||||
import sys
|
||||
|
||||
from pulsar.schema import JsonSchema
|
||||
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
||||
from ... base import GraphRagClientSpec
|
||||
|
||||
from ... base import ConsumerProducer
|
||||
from ... schema import Error
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep
|
||||
from ... schema import agent_request_queue, agent_response_queue
|
||||
from ... schema import prompt_request_queue as pr_request_queue
|
||||
from ... schema import prompt_response_queue as pr_response_queue
|
||||
from ... schema import graph_rag_request_queue as gr_request_queue
|
||||
from ... schema import graph_rag_response_queue as gr_response_queue
|
||||
from ... clients.prompt_client import PromptClient
|
||||
from ... clients.llm_client import LlmClient
|
||||
from ... clients.graph_rag_client import GraphRagClient
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl
|
||||
from . agent_manager import AgentManager
|
||||
|
||||
from . types import Final, Action, Tool, Argument
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
default_ident = "agent-manager"
|
||||
default_max_iterations = 10
|
||||
|
||||
default_input_queue = agent_request_queue
|
||||
default_output_queue = agent_response_queue
|
||||
default_subscriber = module
|
||||
default_max_iterations = 15
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
class Processor(AgentService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
|
||||
self.max_iterations = int(
|
||||
params.get("max_iterations", default_max_iterations)
|
||||
)
|
||||
|
||||
tools = {}
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
prompt_request_queue = params.get(
|
||||
"prompt_request_queue", pr_request_queue
|
||||
)
|
||||
prompt_response_queue = params.get(
|
||||
"prompt_response_queue", pr_response_queue
|
||||
)
|
||||
graph_rag_request_queue = params.get(
|
||||
"graph_rag_request_queue", gr_request_queue
|
||||
)
|
||||
graph_rag_response_queue = params.get(
|
||||
"graph_rag_response_queue", gr_response_queue
|
||||
)
|
||||
|
||||
self.config_key = params.get("config_type", "agent")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": AgentRequest,
|
||||
"output_schema": AgentResponse,
|
||||
"prompt_request_queue": prompt_request_queue,
|
||||
"prompt_response_queue": prompt_response_queue,
|
||||
"graph_rag_request_queue": gr_request_queue,
|
||||
"graph_rag_response_queue": gr_response_queue,
|
||||
"id": id,
|
||||
"max_iterations": self.max_iterations,
|
||||
"config_type": self.config_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.prompt = PromptClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=prompt_request_queue,
|
||||
output_queue=prompt_response_queue,
|
||||
pulsar_host = self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
)
|
||||
|
||||
self.graph_rag = GraphRagClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=graph_rag_request_queue,
|
||||
output_queue=graph_rag_response_queue,
|
||||
pulsar_host = self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
)
|
||||
|
||||
# Need to be able to feed requests to myself
|
||||
self.recursive_input = self.client.create_producer(
|
||||
topic=input_queue,
|
||||
schema=JsonSchema(AgentRequest),
|
||||
)
|
||||
|
||||
self.agent = AgentManager(
|
||||
context=self,
|
||||
tools=[],
|
||||
additional_context="",
|
||||
)
|
||||
|
||||
async def on_config(self, version, config):
|
||||
self.config_handlers.append(self.on_tools_config)
|
||||
|
||||
self.register_specification(
|
||||
TextCompletionClientSpec(
|
||||
request_name = "text-completion-request",
|
||||
response_name = "text-completion-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
GraphRagClientSpec(
|
||||
request_name = "graph-rag-request",
|
||||
response_name = "graph-rag-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
PromptClientSpec(
|
||||
request_name = "prompt-request",
|
||||
response_name = "prompt-response",
|
||||
)
|
||||
)
|
||||
|
||||
async def on_tools_config(self, config, version):
|
||||
|
||||
print("Loading configuration version", version)
|
||||
|
||||
|
|
@ -138,9 +103,9 @@ class Processor(ConsumerProducer):
|
|||
impl_id = data.get("type")
|
||||
|
||||
if impl_id == "knowledge-query":
|
||||
impl = KnowledgeQueryImpl(self)
|
||||
impl = KnowledgeQueryImpl
|
||||
elif impl_id == "text-completion":
|
||||
impl = TextCompletionImpl(self)
|
||||
impl = TextCompletionImpl
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Tool-kind {impl_id} not known"
|
||||
|
|
@ -155,7 +120,6 @@ class Processor(ConsumerProducer):
|
|||
)
|
||||
|
||||
self.agent = AgentManager(
|
||||
context=self,
|
||||
tools=tools,
|
||||
additional_context=additional
|
||||
)
|
||||
|
|
@ -164,19 +128,14 @@ class Processor(ConsumerProducer):
|
|||
|
||||
except Exception as e:
|
||||
|
||||
print("Exception:", e, flush=True)
|
||||
print("on_tools_config Exception:", e, flush=True)
|
||||
print("Configuration reload failed", flush=True)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def agent_request(self, request, respond, next, flow):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
if v.history:
|
||||
if request.history:
|
||||
history = [
|
||||
Action(
|
||||
thought=h.thought,
|
||||
|
|
@ -184,12 +143,12 @@ class Processor(ConsumerProducer):
|
|||
arguments=h.arguments,
|
||||
observation=h.observation
|
||||
)
|
||||
for h in v.history
|
||||
for h in request.history
|
||||
]
|
||||
else:
|
||||
history = []
|
||||
|
||||
print(f"Question: {v.question}", flush=True)
|
||||
print(f"Question: {request.question}", flush=True)
|
||||
|
||||
if len(history) >= self.max_iterations:
|
||||
raise RuntimeError("Too many agent iterations")
|
||||
|
|
@ -207,7 +166,7 @@ class Processor(ConsumerProducer):
|
|||
observation=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
await respond(r)
|
||||
|
||||
async def observe(x):
|
||||
|
||||
|
|
@ -220,15 +179,21 @@ class Processor(ConsumerProducer):
|
|||
observation=x,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
await respond(r)
|
||||
|
||||
act = await self.agent.react(v.question, history, think, observe)
|
||||
act = await self.agent.react(
|
||||
question = request.question,
|
||||
history = history,
|
||||
think = think,
|
||||
observe = observe,
|
||||
context = flow,
|
||||
)
|
||||
|
||||
print(f"Action: {act}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
if isinstance(act, Final):
|
||||
|
||||
if type(act) == Final:
|
||||
print("Send final response...", flush=True)
|
||||
|
||||
r = AgentResponse(
|
||||
answer=act.final,
|
||||
|
|
@ -236,18 +201,20 @@ class Processor(ConsumerProducer):
|
|||
thought=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
await respond(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
return
|
||||
|
||||
print("Send next...", flush=True)
|
||||
|
||||
history.append(act)
|
||||
|
||||
r = AgentRequest(
|
||||
question=v.question,
|
||||
plan=v.plan,
|
||||
state=v.state,
|
||||
question=request.question,
|
||||
plan=request.plan,
|
||||
state=request.state,
|
||||
history=[
|
||||
AgentStep(
|
||||
thought=h.thought,
|
||||
|
|
@ -259,7 +226,7 @@ class Processor(ConsumerProducer):
|
|||
]
|
||||
)
|
||||
|
||||
self.recursive_input.send(r, properties={"id": id})
|
||||
await next(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
|
|
@ -267,7 +234,7 @@ class Processor(ConsumerProducer):
|
|||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
print(f"agent_request Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
|
|
@ -279,39 +246,12 @@ class Processor(ConsumerProducer):
|
|||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
await respond(r)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=pr_request_queue,
|
||||
help=f'Prompt request queue (default: {pr_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-response-queue',
|
||||
default=pr_response_queue,
|
||||
help=f'Prompt response queue (default: {pr_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-rag-request-queue',
|
||||
default=gr_request_queue,
|
||||
help=f'Graph RAG request queue (default: {gr_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-rag-response-queue',
|
||||
default=gr_response_queue,
|
||||
help=f'Graph RAG response queue (default: {gr_response_queue})',
|
||||
)
|
||||
AgentService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--max-iterations',
|
||||
|
|
@ -327,5 +267,5 @@ class Processor(ConsumerProducer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,16 +4,22 @@
|
|||
class KnowledgeQueryImpl:
|
||||
def __init__(self, context):
|
||||
self.context = context
|
||||
def invoke(self, **arguments):
|
||||
return self.context.graph_rag.request(arguments.get("question"))
|
||||
async def invoke(self, **arguments):
|
||||
client = self.context("graph-rag-request")
|
||||
print("Graph RAG question...", flush=True)
|
||||
return await client.rag(
|
||||
arguments.get("question")
|
||||
)
|
||||
|
||||
# This tool implementation knows how to do text completion. This uses
|
||||
# the prompt service, rather than talking to TextCompletion directly.
|
||||
class TextCompletionImpl:
|
||||
def __init__(self, context):
|
||||
self.context = context
|
||||
def invoke(self, **arguments):
|
||||
return self.context.prompt.request(
|
||||
"question", { "question": arguments.get("question") }
|
||||
async def invoke(self, **arguments):
|
||||
client = self.context("prompt-request")
|
||||
print("Prompt question...", flush=True)
|
||||
return await client.question(
|
||||
arguments.get("question")
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue