trustgraph/trustgraph-flow/trustgraph/agent/orchestrator/service.py
cybermaggedon 2bcf375103
Wire message_id on all answer chunks, fix DAG structure (#748)
Wire message_id on all answer chunks, fix DAG structure message_id:
- Add message_id to AgentAnswer dataclass and propagate in
  socket_client._parse_chunk
- Wire message_id into answer callbacks and send_final_response
  for all three patterns (react, plan-then-execute, supervisor)
- Supervisor decomposition thought and synthesis answer chunks
  now carry message_id

DAG structure fixes:
- Observation derives from sub-trace Synthesis (not Analysis)
  when a tool produces a sub-trace; tracked via
  last_sub_explain_uri on context
- Subagent sessions derive from parent's Decomposition via
  parent_uri on agent_session_triples
- Findings derive from subagent Conclusions (not Decomposition)
- Synthesis derives from all findings (multiple wasDerivedFrom)
  ensuring single terminal node
- agent_synthesis_triples accepts list of parent URIs
- Explainability chain walker follows from sub-trace terminal
  to find downstream Observation

Emit Analysis before tool execution:
- Add on_action callback to react() in agent_manager.py, called
  after reason() but before tool invocation
- Orchestrator and old service emit Analysis+ToolUse triples via
  on_action so sub-traces appear after their parent in the stream
2026-04-01 13:27:41 +01:00

592 lines
21 KiB
Python

"""
Agent orchestrator service — multi-pattern drop-in replacement for
agent-manager-react.
Uses the same service identity and Pulsar queues. Adds meta-routing
to select between ReactPattern, PlanThenExecutePattern, and
SupervisorPattern at runtime.
"""
import asyncio
import base64
import json
import functools
import logging
import uuid
from datetime import datetime
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
from ... base import ProducerSpec
from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from trustgraph.provenance import (
agent_session_uri,
GRAPH_RETRIEVAL,
)
from ..react.tools import (
KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl,
StructuredQueryImpl, RowEmbeddingsQueryImpl, ToolServiceImpl,
)
from ..react.agent_manager import AgentManager
from ..tool_filter import validate_tool_config
from ..react.types import Final, Action, Tool, Argument
from . meta_router import MetaRouter
from . pattern_base import PatternBase, UserAwareContext
from . react_pattern import ReactPattern
from . plan_pattern import PlanThenExecutePattern
from . supervisor_pattern import SupervisorPattern
from . aggregator import Aggregator
logger = logging.getLogger(__name__)
default_ident = "agent-manager"
default_max_iterations = 10
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(AgentService):
def __init__(self, **params):
id = params.get("id")
self.max_iterations = int(
params.get("max_iterations", default_max_iterations)
)
self.config_key = params.get("config_type", "agent")
super(Processor, self).__init__(
**params | {
"id": id,
"max_iterations": self.max_iterations,
"config_type": self.config_key,
}
)
self.agent = AgentManager(
tools={},
additional_context="",
)
self.tool_service_clients = {}
# Patterns
self.react_pattern = ReactPattern(self)
self.plan_pattern = PlanThenExecutePattern(self)
self.supervisor_pattern = SupervisorPattern(self)
# Aggregator for supervisor fan-in
self.aggregator = Aggregator()
# Meta-router (initialised on first config load)
self.meta_router = None
self.config_handlers.append(self.on_tools_config)
self.register_specification(
TextCompletionClientSpec(
request_name="text-completion-request",
response_name="text-completion-response",
)
)
self.register_specification(
GraphRagClientSpec(
request_name="graph-rag-request",
response_name="graph-rag-response",
)
)
self.register_specification(
PromptClientSpec(
request_name="prompt-request",
response_name="prompt-response",
)
)
self.register_specification(
ToolClientSpec(
request_name="mcp-tool-request",
response_name="mcp-tool-response",
)
)
self.register_specification(
StructuredQueryClientSpec(
request_name="structured-query-request",
response_name="structured-query-response",
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name="embeddings-request",
response_name="embeddings-response",
)
)
self.register_specification(
RowEmbeddingsQueryClientSpec(
request_name="row-embeddings-query-request",
response_name="row-embeddings-query-response",
)
)
# Explainability producer
self.register_specification(
ProducerSpec(
name="explainability",
schema=Triples,
)
)
# Librarian client
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
self.pending_librarian_requests = {}
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
response = msg.value()
request_id = msg.properties().get("id")
if request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, user, content, title=None,
timeout=120):
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or "Agent Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
)
future = asyncio.get_event_loop().create_future()
self.pending_librarian_requests[request_id] = future
try:
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: "
f"{response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
def provenance_session_uri(self, session_id):
return agent_session_uri(session_id)
async def on_tools_config(self, config, version):
logger.info(f"Loading configuration version {version}")
try:
tools = {}
# Load tool-service configurations
tool_services = {}
if "tool-service" in config:
for service_id, service_value in config["tool-service"].items():
service_data = json.loads(service_value)
tool_services[service_id] = service_data
logger.debug(f"Loaded tool-service config: {service_id}")
logger.info(
f"Loaded {len(tool_services)} tool-service configurations"
)
# Load tool configurations
if "tool" in config:
for tool_id, tool_value in config["tool"].items():
data = json.loads(tool_value)
impl_id = data.get("type")
name = data.get("name")
if impl_id == "knowledge-query":
impl = functools.partial(
KnowledgeQueryImpl,
collection=data.get("collection"),
)
arguments = KnowledgeQueryImpl.get_arguments()
elif impl_id == "text-completion":
impl = TextCompletionImpl
arguments = TextCompletionImpl.get_arguments()
elif impl_id == "mcp-tool":
config_args = data.get("arguments", [])
arguments = [
Argument(
name=arg.get("name"),
type=arg.get("type"),
description=arg.get("description"),
)
for arg in config_args
]
impl = functools.partial(
McpToolImpl,
mcp_tool_id=data.get("mcp-tool"),
arguments=arguments,
)
elif impl_id == "prompt":
config_args = data.get("arguments", [])
arguments = [
Argument(
name=arg.get("name"),
type=arg.get("type"),
description=arg.get("description"),
)
for arg in config_args
]
impl = functools.partial(
PromptImpl,
template_id=data.get("template"),
arguments=arguments,
)
elif impl_id == "structured-query":
impl = functools.partial(
StructuredQueryImpl,
collection=data.get("collection"),
user=None,
)
arguments = StructuredQueryImpl.get_arguments()
elif impl_id == "row-embeddings-query":
impl = functools.partial(
RowEmbeddingsQueryImpl,
schema_name=data.get("schema-name"),
collection=data.get("collection"),
user=None,
index_name=data.get("index-name"),
limit=int(data.get("limit", 10)),
)
arguments = RowEmbeddingsQueryImpl.get_arguments()
elif impl_id == "tool-service":
service_ref = data.get("service")
if not service_ref:
raise RuntimeError(
f"Tool {name} has type 'tool-service' "
f"but no 'service' reference"
)
if service_ref not in tool_services:
raise RuntimeError(
f"Tool {name} references unknown "
f"tool-service '{service_ref}'"
)
service_config = tool_services[service_ref]
request_queue = service_config.get("request-queue")
response_queue = service_config.get("response-queue")
if not request_queue or not response_queue:
raise RuntimeError(
f"Tool-service '{service_ref}' must define "
f"'request-queue' and 'response-queue'"
)
config_params = service_config.get("config-params", [])
config_values = {}
for param in config_params:
param_name = (
param.get("name")
if isinstance(param, dict) else param
)
if param_name in data:
config_values[param_name] = data[param_name]
elif (
isinstance(param, dict)
and param.get("required", False)
):
raise RuntimeError(
f"Tool {name} missing required config "
f"param '{param_name}'"
)
config_args = data.get("arguments", [])
arguments = [
Argument(
name=arg.get("name"),
type=arg.get("type"),
description=arg.get("description"),
)
for arg in config_args
]
impl = functools.partial(
ToolServiceImpl,
request_queue=request_queue,
response_queue=response_queue,
config_values=config_values,
arguments=arguments,
processor=self,
)
else:
raise RuntimeError(
f"Tool type {impl_id} not known"
)
validate_tool_config(data)
tools[name] = Tool(
name=name,
description=data.get("description"),
implementation=impl,
config=data,
arguments=arguments,
)
# Load additional context from agent config
additional = None
if self.config_key in config:
agent_config = config[self.config_key]
additional = agent_config.get("additional-context", None)
self.agent = AgentManager(
tools=tools,
additional_context=additional,
)
# Re-initialise meta-router with config
self.meta_router = MetaRouter(config=config)
logger.info(f"Loaded {len(tools)} tools")
except Exception as e:
logger.error(
f"on_tools_config Exception: {e}", exc_info=True
)
logger.error("Configuration reload failed")
async def _handle_subagent_completion(self, request, respond, next, flow):
"""Handle a subagent completion by feeding it to the aggregator."""
correlation_id = request.correlation_id
subagent_goal = getattr(request, 'subagent_goal', '')
parent_session_id = getattr(request, 'parent_session_id', '')
# Extract the answer from the completion step
answer_text = ""
for step in request.history:
if getattr(step, 'step_type', '') == 'subagent-completion':
answer_text = step.observation
break
logger.debug(
f"Received subagent completion: "
f"correlation={correlation_id}, goal={subagent_goal}"
)
all_done = self.aggregator.record_completion(
correlation_id, subagent_goal, answer_text
)
if all_done is None:
logger.warning(
f"Unknown correlation_id {correlation_id}"
f"possibly timed out or duplicate"
)
return
# Emit finding provenance for this subagent
template = self.aggregator.get_original_request(correlation_id)
if template and parent_session_id:
entry = self.aggregator.correlations.get(correlation_id)
finding_index = len(entry["results"]) - 1 if entry else 0
collection = getattr(template, 'collection', 'default')
subagent_session_id = getattr(request, 'session_id', '')
await self.supervisor_pattern.emit_finding_triples(
flow, parent_session_id, finding_index,
subagent_goal, answer_text,
template.user, collection,
respond, template.streaming,
subagent_session_id=subagent_session_id,
)
if all_done:
logger.info(
f"All subagents complete for {correlation_id}, "
f"dispatching synthesis"
)
if template is None:
logger.error(
f"No template for correlation {correlation_id}"
)
return
synthesis_request = self.aggregator.build_synthesis_request(
correlation_id,
original_question=template.question,
user=template.user,
collection=getattr(template, 'collection', 'default'),
)
await next(synthesis_request)
async def agent_request(self, request, respond, next, flow):
try:
# Intercept subagent completion messages
correlation_id = getattr(request, 'correlation_id', '')
if correlation_id and request.history:
is_completion = any(
getattr(h, 'step_type', '') == 'subagent-completion'
for h in request.history
)
if is_completion:
await self._handle_subagent_completion(
request, respond, next, flow
)
return
pattern = getattr(request, 'pattern', '') or ''
# If no pattern set and this is the first iteration, route
if not pattern and not request.history:
context = UserAwareContext(flow, request.user)
if self.meta_router:
pattern, task_type, framing = await self.meta_router.route(
request.question, context,
)
else:
pattern = "react"
task_type = "general"
framing = ""
# Update request with routing decision
request.pattern = pattern
request.task_type = task_type
request.framing = framing
logger.info(
f"Routed to pattern={pattern}, "
f"task_type={task_type}"
)
# Dispatch to the selected pattern
if pattern == "plan-then-execute":
await self.plan_pattern.iterate(
request, respond, next, flow,
)
elif pattern == "supervisor":
await self.supervisor_pattern.iterate(
request, respond, next, flow,
)
else:
# Default to react
await self.react_pattern.iterate(
request, respond, next, flow,
)
except Exception as e:
logger.error(
f"agent_request Exception: {e}", exc_info=True
)
logger.debug("Send error response...")
error_obj = Error(
type="agent-error",
message=str(e),
)
r = AgentResponse(
chunk_type="error",
content=str(e),
end_of_message=True,
end_of_dialog=True,
error=error_obj,
)
await respond(r)
@staticmethod
def add_args(parser):
AgentService.add_args(parser)
parser.add_argument(
'--max-iterations',
default=default_max_iterations,
help=f'Maximum number of react iterations '
f'(default: {default_max_iterations})',
)
parser.add_argument(
'--config-type',
default="agent",
help='Configuration key for prompts (default: agent)',
)
def run():
Processor.launch(default_ident, __doc__)