mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Add multi-pattern orchestrator with plan-then-execute and supervisor (#739)
Introduce an agent orchestrator service that supports three execution patterns (ReAct, plan-then-execute, supervisor) with LLM-based meta-routing to select the appropriate pattern and task type per request. Update the agent schema to support orchestration fields (correlation, sub-agents, plan steps) and remove legacy response fields (answer, thought, observation).
This commit is contained in:
parent
7af1d60db8
commit
849987f0e6
21 changed files with 3006 additions and 172 deletions
|
|
@ -56,6 +56,7 @@ Homepage = "https://github.com/trustgraph-ai/trustgraph"
|
|||
|
||||
[project.scripts]
|
||||
agent-manager-react = "trustgraph.agent.react:run"
|
||||
agent-orchestrator = "trustgraph.agent.orchestrator:run"
|
||||
api-gateway = "trustgraph.gateway:run"
|
||||
chunker-recursive = "trustgraph.chunking.recursive:run"
|
||||
chunker-token = "trustgraph.chunking.token:run"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
from . service import *
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
"""Package entry point.

Allows running the orchestrator as a module
(``python -m trustgraph.agent.orchestrator``); delegates to the
service's ``run`` entry point (the same one exposed as the
``agent-orchestrator`` console script).
"""

from . service import run

if __name__ == '__main__':
    run()
|
||||
157
trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py
Normal file
157
trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
"""
|
||||
Aggregator — monitors the explainability topic for subagent completions
|
||||
and triggers synthesis when all siblings in a fan-out have completed.
|
||||
|
||||
The aggregator watches for tg:Conclusion triples that carry a
|
||||
correlation_id. When it detects that all expected siblings have
|
||||
completed, it emits a synthesis AgentRequest on the agent request topic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from ... schema import AgentRequest, AgentStep
|
||||
|
||||
logger = logging.getLogger(__name__)

# How long to wait for stalled correlations before giving up (seconds)
DEFAULT_TIMEOUT = 300


class Aggregator:
    """
    Tracks in-flight fan-out correlations and triggers synthesis
    when all subagents complete.

    State is held in-memory; if the process restarts, in-flight
    correlations are lost (acceptable for v1).
    """

    def __init__(self, timeout=DEFAULT_TIMEOUT):
        """
        Args:
            timeout: Seconds after which a stalled correlation becomes
                eligible for removal by cleanup_stale().
        """
        self.timeout = timeout

        # correlation_id -> {
        #     "parent_session_id": str,
        #     "expected": int,
        #     "results": {goal: answer},
        #     "request_template": AgentRequest or None,
        #     "created_at": float,
        # }
        self.correlations = {}

    def register_fanout(self, correlation_id, parent_session_id,
                        expected_siblings, request_template=None):
        """
        Register a new fan-out. Called by the supervisor after emitting
        subagent requests.

        Args:
            correlation_id: Unique id shared by all sibling subagents.
            parent_session_id: Session that the synthesis reply targets.
            expected_siblings: Number of completions to wait for.
            request_template: Original AgentRequest used to carry
                conversation/streaming settings into synthesis (optional).
        """
        self.correlations[correlation_id] = {
            "parent_session_id": parent_session_id,
            "expected": expected_siblings,
            "results": {},
            "request_template": request_template,
            "created_at": time.time(),
        }
        logger.info(
            f"Aggregator: registered fan-out {correlation_id}, "
            f"expecting {expected_siblings} subagents"
        )

    def record_completion(self, correlation_id, subagent_goal, result):
        """
        Record a subagent completion.

        Returns:
            True if all siblings are now complete, False otherwise.
            Returns None if the correlation_id is unknown.
        """
        if correlation_id not in self.correlations:
            logger.warning(
                f"Aggregator: unknown correlation_id {correlation_id}"
            )
            return None

        entry = self.correlations[correlation_id]

        # FIX: results are keyed by goal text.  If two siblings share the
        # same goal, the second completion used to overwrite the first, so
        # len(results) never reached "expected" and the fan-out stalled
        # until it timed out.  Disambiguate duplicate keys so every
        # completion counts.
        if subagent_goal in entry["results"]:
            suffix = 2
            while f"{subagent_goal} ({suffix})" in entry["results"]:
                suffix += 1
            subagent_goal = f"{subagent_goal} ({suffix})"

        entry["results"][subagent_goal] = result

        completed = len(entry["results"])
        expected = entry["expected"]

        logger.info(
            f"Aggregator: {correlation_id} — "
            f"{completed}/{expected} subagents complete"
        )

        return completed >= expected

    def get_results(self, correlation_id):
        """Get all results for a correlation and remove the tracking entry.

        Returns:
            (results, parent_session_id, request_template), or
            (None, None, None) when the correlation is unknown.
        """
        entry = self.correlations.pop(correlation_id, None)
        if entry is None:
            return None, None, None
        return (
            entry["results"],
            entry["parent_session_id"],
            entry["request_template"],
        )

    def build_synthesis_request(self, correlation_id, original_question,
                                user, collection):
        """
        Build the AgentRequest that triggers the synthesis phase.

        Consumes (removes) the correlation's tracking entry.

        Raises:
            RuntimeError: if the correlation_id is unknown, already
                consumed, or timed out.
        """
        results, parent_session_id, template = self.get_results(correlation_id)

        if results is None:
            raise RuntimeError(
                f"No results for correlation_id {correlation_id}"
            )

        # Build history with decompose step + results
        synthesis_step = AgentStep(
            thought="All subagents completed",
            action="aggregate",
            arguments={},
            observation=json.dumps(results),
            step_type="synthesise",
            subagent_results=results,
        )

        history = []
        if template and template.history:
            history = list(template.history)
        history.append(synthesis_step)

        # Fan-out bookkeeping fields (parent_session_id, subagent_goal,
        # expected_siblings) are deliberately cleared: this request is the
        # parent's synthesis turn, not another subagent.
        return AgentRequest(
            question=original_question,
            state="",
            group=template.group if template else [],
            history=history,
            user=user,
            collection=collection,
            streaming=template.streaming if template else False,
            session_id=parent_session_id,
            conversation_id=template.conversation_id if template else "",
            pattern="supervisor",
            task_type=template.task_type if template else "",
            framing=template.framing if template else "",
            correlation_id=correlation_id,
            parent_session_id="",
            subagent_goal="",
            expected_siblings=0,
        )

    def cleanup_stale(self):
        """Remove correlations that have timed out.

        Returns:
            List of correlation ids that were removed.
        """
        now = time.time()
        stale = [
            cid for cid, entry in self.correlations.items()
            if now - entry["created_at"] > self.timeout
        ]
        for cid in stale:
            logger.warning(f"Aggregator: timing out stale correlation {cid}")
            self.correlations.pop(cid, None)
        return stale
|
||||
168
trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py
Normal file
168
trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""
|
||||
MetaRouter — selects the task type and execution pattern for a query.
|
||||
|
||||
Uses the config API to look up available task types and patterns, then
|
||||
asks the LLM to classify the query and select the best pattern.
|
||||
Falls back to ("react", "general", "") if config is empty.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)

DEFAULT_PATTERN = "react"
DEFAULT_TASK_TYPE = "general"
DEFAULT_FRAMING = ""


class MetaRouter:
    """
    Selects the task type and execution pattern for a query.

    Looks up available task types and patterns from the config dict,
    then asks the LLM to classify the query and select the best pattern.
    Falls back to ("react", "general", "") when config is empty.
    """

    def __init__(self, config=None):
        """
        Args:
            config: The full config dict from the config service.
                May contain "agent-pattern" and "agent-task-type" keys.
        """
        # Both sections use the same {id: json-string} layout, so parse
        # them with one shared helper instead of two duplicated loops.
        self.patterns = self._load_section(config, "agent-pattern")
        self.task_types = self._load_section(config, "agent-task-type")

        # If config has no patterns/task-types, default to react/general
        if not self.patterns:
            self.patterns = {
                "react": {"name": "react", "description": "Interleaved reasoning and action"},
            }
        if not self.task_types:
            self.task_types = {
                "general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""},
            }

    @staticmethod
    def _load_section(config, key):
        """Parse one JSON-valued config section into {id: parsed-value}.

        A malformed entry degrades to a minimal {"name": id} stub rather
        than failing the whole load.
        """
        loaded = {}
        if config and key in config:
            for cid, cval in config[key].items():
                try:
                    loaded[cid] = json.loads(cval)
                except (json.JSONDecodeError, TypeError):
                    loaded[cid] = {"name": cid}
        return loaded

    @staticmethod
    def _normalise_selection(response):
        """Strip whitespace/quotes and lowercase an LLM selection reply."""
        return response.strip().lower().replace('"', '').replace("'", "")

    async def identify_task_type(self, question, context):
        """
        Use the LLM to classify the question into one of the known task types.

        Args:
            question: The user's query.
            context: UserAwareContext (flow wrapper).

        Returns:
            (task_type_id, framing) tuple.
        """
        # Zero or one known task type: no LLM call needed.
        if len(self.task_types) <= 1:
            tid = next(iter(self.task_types), DEFAULT_TASK_TYPE)
            framing = self.task_types.get(tid, {}).get("framing", DEFAULT_FRAMING)
            return tid, framing

        try:
            client = context("prompt-request")
            response = await client.prompt(
                id="task-type-classify",
                variables={
                    "question": question,
                    "task_types": [
                        {"name": tid, "description": tdata.get("description", tid)}
                        for tid, tdata in self.task_types.items()
                    ],
                },
            )
            selected = self._normalise_selection(response)

            if selected in self.task_types:
                framing = self.task_types[selected].get("framing", DEFAULT_FRAMING)
                logger.info(f"MetaRouter: identified task type '{selected}'")
                return selected, framing
            else:
                logger.warning(
                    f"MetaRouter: LLM returned unknown task type '{selected}', "
                    f"falling back to '{DEFAULT_TASK_TYPE}'"
                )
        except Exception as e:
            # Best-effort routing: any failure falls through to defaults.
            logger.warning(f"MetaRouter: task type classification failed: {e}")

        framing = self.task_types.get(DEFAULT_TASK_TYPE, {}).get(
            "framing", DEFAULT_FRAMING
        )
        return DEFAULT_TASK_TYPE, framing

    async def select_pattern(self, question, task_type, context):
        """
        Use the LLM to select the best execution pattern for this task type.

        Args:
            question: The user's query.
            task_type: The identified task type ID.
            context: UserAwareContext (flow wrapper).

        Returns:
            Pattern ID string.
        """
        task_config = self.task_types.get(task_type, {})
        valid_patterns = task_config.get("valid_patterns", list(self.patterns.keys()))

        # Zero or one valid choice: no LLM call needed.
        if len(valid_patterns) <= 1:
            return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN

        try:
            client = context("prompt-request")
            response = await client.prompt(
                id="pattern-select",
                variables={
                    "question": question,
                    "task_type": task_type,
                    "task_type_description": task_config.get("description", task_type),
                    "patterns": [
                        {"name": pid, "description": self.patterns.get(pid, {}).get("description", pid)}
                        for pid in valid_patterns
                        if pid in self.patterns
                    ],
                },
            )
            selected = self._normalise_selection(response)

            if selected in valid_patterns:
                logger.info(f"MetaRouter: selected pattern '{selected}'")
                return selected
            else:
                logger.warning(
                    f"MetaRouter: LLM returned invalid pattern '{selected}', "
                    f"falling back to '{valid_patterns[0]}'"
                )
                return valid_patterns[0]
        except Exception as e:
            logger.warning(f"MetaRouter: pattern selection failed: {e}")
            return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN

    async def route(self, question, context):
        """
        Full routing pipeline: identify task type, then select pattern.

        Args:
            question: The user's query.
            context: UserAwareContext (flow wrapper).

        Returns:
            (pattern, task_type, framing) tuple.
        """
        task_type, framing = await self.identify_task_type(question, context)
        pattern = await self.select_pattern(question, task_type, context)
        logger.info(
            f"MetaRouter: route result — "
            f"pattern={pattern}, task_type={task_type}, framing={framing!r}"
        )
        return pattern, task_type, framing
|
||||
428
trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py
Normal file
428
trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py
Normal file
|
|
@ -0,0 +1,428 @@
|
|||
"""
|
||||
Base class for agent patterns.
|
||||
|
||||
Provides shared infrastructure used by all patterns: tool filtering,
|
||||
provenance emission, streaming callbacks, history management, and
|
||||
librarian integration.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from ... schema import Triples, Metadata
|
||||
|
||||
from trustgraph.provenance import (
|
||||
agent_session_uri,
|
||||
agent_iteration_uri,
|
||||
agent_thought_uri,
|
||||
agent_observation_uri,
|
||||
agent_final_uri,
|
||||
agent_session_triples,
|
||||
agent_iteration_triples,
|
||||
agent_final_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
from ..react.types import Action, Final
|
||||
from ..tool_filter import filter_tools_by_group_and_state, get_next_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserAwareContext:
    """Wraps flow interface to inject user context for tools that need it."""

    # Services whose client objects must know the calling user.
    _USER_SCOPED_SERVICES = frozenset((
        "structured-query-request",
        "row-embeddings-query-request",
    ))

    def __init__(self, flow, user):
        self._flow = flow
        self._user = user

    def __call__(self, service_name):
        """Resolve a service client, tagging it with the user when the
        service is user-scoped; other clients pass through untouched."""
        client = self._flow(service_name)
        if service_name in self._USER_SCOPED_SERVICES:
            client._current_user = self._user
        return client
|
||||
|
||||
|
||||
class PatternBase:
    """
    Shared infrastructure for all agent patterns.

    Subclasses implement iterate() to perform one iteration of their
    pattern-specific logic.  Everything else here is shared plumbing:
    tool filtering, streaming callbacks, provenance triple emission,
    librarian persistence, and next-iteration request construction.
    """

    def __init__(self, processor):
        # The owning processor supplies shared services used below,
        # e.g. save_answer_content() and (in subclasses) max_iterations.
        self.processor = processor

    def filter_tools(self, tools, request):
        """Apply group/state filtering to the tool set.

        Missing request attributes default to None, i.e. no filtering
        constraint from that dimension.
        """
        return filter_tools_by_group_and_state(
            tools=tools,
            requested_groups=getattr(request, 'group', None),
            current_state=getattr(request, 'state', None),
        )

    def make_context(self, flow, user):
        """Create a user-aware context wrapper around the flow interface."""
        return UserAwareContext(flow, user)

    def build_history(self, request):
        """Convert AgentStep history into Action objects.

        Returns an empty list when the request carries no history.
        """
        if not request.history:
            return []
        return [
            Action(
                thought=h.thought,
                name=h.action,
                arguments=h.arguments,
                observation=h.observation,
            )
            for h in request.history
        ]

    # ---- Streaming callbacks ------------------------------------------------

    def make_think_callback(self, respond, streaming):
        """Create the think callback for streaming/non-streaming.

        Streaming mode forwards the caller's is_final flag as
        end_of_message; non-streaming mode marks every thought as a
        complete message.
        """
        async def think(x, is_final=False):
            logger.debug(f"Think: {x} (is_final={is_final})")
            if streaming:
                r = AgentResponse(
                    chunk_type="thought",
                    content=x,
                    end_of_message=is_final,
                    end_of_dialog=False,
                )
            else:
                r = AgentResponse(
                    chunk_type="thought",
                    content=x,
                    end_of_message=True,
                    end_of_dialog=False,
                )
            await respond(r)
        return think

    def make_observe_callback(self, respond, streaming):
        """Create the observe callback for streaming/non-streaming.

        Same end_of_message semantics as make_think_callback, with
        chunk_type "observation".
        """
        async def observe(x, is_final=False):
            logger.debug(f"Observe: {x} (is_final={is_final})")
            if streaming:
                r = AgentResponse(
                    chunk_type="observation",
                    content=x,
                    end_of_message=is_final,
                    end_of_dialog=False,
                )
            else:
                r = AgentResponse(
                    chunk_type="observation",
                    content=x,
                    end_of_message=True,
                    end_of_dialog=False,
                )
            await respond(r)
        return observe

    def make_answer_callback(self, respond, streaming):
        """Create the answer callback for streaming/non-streaming.

        Unlike think/observe, streaming answers are never marked
        end_of_message here — the final marker is sent separately
        (see send_final_response).
        """
        async def answer(x):
            logger.debug(f"Answer: {x}")
            if streaming:
                r = AgentResponse(
                    chunk_type="answer",
                    content=x,
                    end_of_message=False,
                    end_of_dialog=False,
                )
            else:
                r = AgentResponse(
                    chunk_type="answer",
                    content=x,
                    end_of_message=True,
                    end_of_dialog=False,
                )
            await respond(r)
        return answer

    # ---- Provenance emission ------------------------------------------------

    async def emit_session_triples(self, flow, session_uri, question, user,
                                   collection, respond, streaming):
        """Emit provenance triples for a new session.

        When streaming, also pushes an "explain" chunk so the client can
        follow the provenance graph as it is built.
        """
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # consider datetime.now(timezone.utc) — would change the rendered
        # suffix, so confirm downstream consumers of the "Z" format first.
        timestamp = datetime.utcnow().isoformat() + "Z"
        triples = set_graph(
            agent_session_triples(session_uri, question, timestamp),
            GRAPH_RETRIEVAL,
        )
        await flow("explainability").send(Triples(
            metadata=Metadata(
                id=session_uri,
                user=user,
                collection=collection,
            ),
            triples=triples,
        ))
        logger.debug(f"Emitted session triples for {session_uri}")

        if streaming:
            await respond(AgentResponse(
                chunk_type="explain",
                content="",
                explain_id=session_uri,
                explain_graph=GRAPH_RETRIEVAL,
            ))

    async def emit_iteration_triples(self, flow, session_id, iteration_num,
                                     session_uri, act, request, respond,
                                     streaming):
        """Emit provenance triples for an iteration and save to librarian.

        Args:
            act: The Action executed this iteration (thought/name/
                arguments/observation).
            iteration_num: 1-based iteration counter.
        """
        iteration_uri = agent_iteration_uri(session_id, iteration_num)

        # The first iteration links back to the session question; later
        # iterations chain to their predecessor instead.
        if iteration_num > 1:
            iter_question_uri = None
            iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
        else:
            iter_question_uri = session_uri
            iter_previous_uri = None

        # Save thought to librarian — best-effort: on failure only the
        # document link is dropped, the iteration triples still go out.
        thought_doc_id = None
        if act.thought:
            thought_doc_id = (
                f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
            )
            try:
                await self.processor.save_answer_content(
                    doc_id=thought_doc_id,
                    user=request.user,
                    content=act.thought,
                    title=f"Agent Thought: {act.name}",
                )
            except Exception as e:
                logger.warning(f"Failed to save thought to librarian: {e}")
                thought_doc_id = None

        # Save observation to librarian (same best-effort policy)
        observation_doc_id = None
        if act.observation:
            observation_doc_id = (
                f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
            )
            try:
                await self.processor.save_answer_content(
                    doc_id=observation_doc_id,
                    user=request.user,
                    content=act.observation,
                    title=f"Agent Observation: {act.name}",
                )
            except Exception as e:
                logger.warning(f"Failed to save observation to librarian: {e}")
                observation_doc_id = None

        thought_entity_uri = agent_thought_uri(session_id, iteration_num)
        observation_entity_uri = agent_observation_uri(session_id, iteration_num)

        # Entity URIs are only attached when the corresponding document
        # was actually persisted above.
        iter_triples = set_graph(
            agent_iteration_triples(
                iteration_uri,
                question_uri=iter_question_uri,
                previous_uri=iter_previous_uri,
                action=act.name,
                arguments=act.arguments,
                thought_uri=thought_entity_uri if thought_doc_id else None,
                thought_document_id=thought_doc_id,
                observation_uri=observation_entity_uri if observation_doc_id else None,
                observation_document_id=observation_doc_id,
            ),
            GRAPH_RETRIEVAL,
        )
        await flow("explainability").send(Triples(
            metadata=Metadata(
                id=iteration_uri,
                user=request.user,
                collection=getattr(request, 'collection', 'default'),
            ),
            triples=iter_triples,
        ))
        logger.debug(f"Emitted iteration triples for {iteration_uri}")

        if streaming:
            await respond(AgentResponse(
                chunk_type="explain",
                content="",
                explain_id=iteration_uri,
                explain_graph=GRAPH_RETRIEVAL,
            ))

    async def emit_final_triples(self, flow, session_id, iteration_num,
                                 session_uri, answer_text, request, respond,
                                 streaming):
        """Emit provenance triples for the final answer and save to librarian."""
        final_uri = agent_final_uri(session_id)

        # Same chaining rule as iterations: link to the question when this
        # is the only step, otherwise to the last iteration.
        if iteration_num > 1:
            final_question_uri = None
            final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
        else:
            final_question_uri = session_uri
            final_previous_uri = None

        # Save answer to librarian (best-effort, as above)
        answer_doc_id = None
        if answer_text:
            answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
            try:
                await self.processor.save_answer_content(
                    doc_id=answer_doc_id,
                    user=request.user,
                    content=answer_text,
                    title=f"Agent Answer: {request.question[:50]}...",
                )
                logger.debug(f"Saved answer to librarian: {answer_doc_id}")
            except Exception as e:
                logger.warning(f"Failed to save answer to librarian: {e}")
                answer_doc_id = None

        final_triples = set_graph(
            agent_final_triples(
                final_uri,
                question_uri=final_question_uri,
                previous_uri=final_previous_uri,
                document_id=answer_doc_id,
            ),
            GRAPH_RETRIEVAL,
        )
        await flow("explainability").send(Triples(
            metadata=Metadata(
                id=final_uri,
                user=request.user,
                collection=getattr(request, 'collection', 'default'),
            ),
            triples=final_triples,
        ))
        logger.debug(f"Emitted final triples for {final_uri}")

        if streaming:
            await respond(AgentResponse(
                chunk_type="explain",
                content="",
                explain_id=final_uri,
                explain_graph=GRAPH_RETRIEVAL,
            ))

    # ---- Response helpers ---------------------------------------------------

    async def prompt_as_answer(self, client, prompt_id, variables,
                               respond, streaming):
        """Call a prompt template, forwarding chunks as answer
        AgentResponse messages when streaming is enabled.

        Returns the full accumulated answer text (needed for provenance).
        """
        if streaming:
            accumulated = []

            # end_of_stream is accepted but unused here: the end-of-dialog
            # marker is emitted separately by send_final_response().
            async def on_chunk(text, end_of_stream):
                if text:
                    accumulated.append(text)
                    await respond(AgentResponse(
                        chunk_type="answer",
                        content=text,
                        end_of_message=False,
                        end_of_dialog=False,
                    ))

            await client.prompt(
                id=prompt_id,
                variables=variables,
                streaming=True,
                chunk_callback=on_chunk,
            )

            return "".join(accumulated)
        else:
            return await client.prompt(
                id=prompt_id,
                variables=variables,
            )

    async def send_final_response(self, respond, streaming, answer_text,
                                  already_streamed=False):
        """Send the answer content and end-of-dialog marker.

        Args:
            already_streamed: If True, answer chunks were already sent
                via streaming callbacks (e.g. ReactPattern). Only the
                end-of-dialog marker is emitted.
        """
        if streaming and not already_streamed:
            # Answer wasn't streamed yet — send it as a chunk first
            if answer_text:
                await respond(AgentResponse(
                    chunk_type="answer",
                    content=answer_text,
                    end_of_message=False,
                    end_of_dialog=False,
                ))
        if streaming:
            # End-of-dialog marker
            await respond(AgentResponse(
                chunk_type="answer",
                content="",
                end_of_message=True,
                end_of_dialog=True,
            ))
        else:
            # Non-streaming: one message carrying the whole answer
            await respond(AgentResponse(
                chunk_type="answer",
                content=answer_text,
                end_of_message=True,
                end_of_dialog=True,
            ))

    def build_next_request(self, request, history, session_id, collection,
                           streaming, next_state):
        """Build the AgentRequest for the next iteration.

        Args:
            history: List of Action objects (not AgentStep) — they are
                converted back to AgentStep entries here.
            next_state: Tool-state machine value for the next iteration.
        """
        return AgentRequest(
            question=request.question,
            state=next_state,
            group=getattr(request, 'group', []),
            history=[
                AgentStep(
                    thought=h.thought,
                    action=h.name,
                    # argument values are stringified — presumably the
                    # AgentStep schema carries a string map; confirm.
                    arguments={k: str(v) for k, v in h.arguments.items()},
                    observation=h.observation,
                )
                for h in history
            ],
            user=request.user,
            collection=collection,
            streaming=streaming,
            session_id=session_id,
            # Preserve orchestration fields
            conversation_id=getattr(request, 'conversation_id', ''),
            pattern=getattr(request, 'pattern', ''),
            task_type=getattr(request, 'task_type', ''),
            framing=getattr(request, 'framing', ''),
            correlation_id=getattr(request, 'correlation_id', ''),
            parent_session_id=getattr(request, 'parent_session_id', ''),
            subagent_goal=getattr(request, 'subagent_goal', ''),
            expected_siblings=getattr(request, 'expected_siblings', 0),
        )

    # NOTE: the 'next' parameter shadows the builtin; kept for interface
    # compatibility with the processor's dispatch signature.
    async def iterate(self, request, respond, next, flow):
        """
        Perform one iteration of this pattern.

        Must be implemented by subclasses.
        """
        raise NotImplementedError
|
||||
349
trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py
Normal file
349
trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py
Normal file
|
|
@ -0,0 +1,349 @@
|
|||
"""
|
||||
PlanThenExecutePattern — structured planning followed by step execution.
|
||||
|
||||
Phase 1 (planning): LLM produces a structured plan of steps.
|
||||
Phase 2 (execution): Each step is executed via single-shot tool call.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, PlanStep
|
||||
|
||||
from ..react.types import Action
|
||||
|
||||
from . pattern_base import PatternBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlanThenExecutePattern(PatternBase):
|
||||
"""
|
||||
Plan-then-Execute pattern.
|
||||
|
||||
History tracks the current phase via AgentStep.step_type:
|
||||
- "plan" step: contains the plan in step.plan
|
||||
- "execute" step: a normal react iteration executing a plan step
|
||||
|
||||
On the first call (empty history), a planning iteration is run.
|
||||
Subsequent calls execute the next pending plan step via ReACT.
|
||||
"""
|
||||
|
||||
    async def iterate(self, request, respond, next, flow):
        """Run one plan-then-execute iteration.

        Dispatches on the phase recorded in history: no plan yet means a
        planning iteration; otherwise execute the next pending plan step.

        Args:
            request: Current AgentRequest (history drives the phase).
            respond: Async callback for AgentResponse chunks.
            next: Async callback that enqueues the follow-up AgentRequest
                (shadows the builtin; name kept for interface parity).
            flow: Service resolver for downstream clients.
        """

        streaming = getattr(request, 'streaming', False)
        # First iteration has no session id yet — mint one.
        session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
        collection = getattr(request, 'collection', 'default')

        # NOTE(review): 'history' is currently unused in this method —
        # phase detection below reads request.history directly.
        history = self.build_history(request)
        iteration_num = len(request.history) + 1
        session_uri = self.processor.provenance_session_uri(session_id)

        # Emit session provenance on first iteration
        if iteration_num == 1:
            await self.emit_session_triples(
                flow, session_uri, request.question,
                request.user, collection, respond, streaming,
            )

        logger.info(
            f"PlanThenExecutePattern iteration {iteration_num}: "
            f"{request.question}"
        )

        # Hard stop against runaway plans / loops
        if iteration_num >= self.processor.max_iterations:
            raise RuntimeError("Too many agent iterations")

        # Determine current phase by checking history for a plan step
        plan = self._extract_plan(request.history)

        if plan is None:
            await self._planning_iteration(
                request, respond, next, flow,
                session_id, collection, streaming, session_uri,
                iteration_num,
            )
        else:
            await self._execution_iteration(
                request, respond, next, flow,
                session_id, collection, streaming, session_uri,
                iteration_num, plan,
            )
|
||||
|
||||
def _extract_plan(self, history):
|
||||
"""Find the most recent plan from history.
|
||||
|
||||
Checks execute steps first (they carry the updated plan with
|
||||
completion statuses), then falls back to the original plan step.
|
||||
"""
|
||||
if not history:
|
||||
return None
|
||||
for step in reversed(history):
|
||||
if step.plan:
|
||||
return list(step.plan)
|
||||
return None
|
||||
|
||||
def _find_next_pending_step(self, plan):
|
||||
"""Return index of the next pending step, or None if all done."""
|
||||
for i, step in enumerate(plan):
|
||||
if getattr(step, 'status', 'pending') == 'pending':
|
||||
return i
|
||||
return None
|
||||
|
||||
    async def _planning_iteration(self, request, respond, next, flow,
                                  session_id, collection, streaming,
                                  session_uri, iteration_num):
        """Ask the LLM to produce a structured plan.

        Calls the "plan-create" prompt template, converts the returned
        step dicts into PlanStep objects, records them in a "plan"
        history step, and enqueues the follow-up request so the next
        iteration enters the execution phase.
        """

        think = self.make_think_callback(respond, streaming)

        tools = self.filter_tools(self.processor.agent.tools, request)
        framing = getattr(request, 'framing', '')

        context = self.make_context(flow, request.user)
        client = context("prompt-request")

        # Use the plan-create prompt template.
        # NOTE(review): assumes the prompt client deserialises the reply
        # into a Python list of step dicts — confirm against the prompt
        # service contract.
        plan_steps = await client.prompt(
            id="plan-create",
            variables={
                "question": request.question,
                "framing": framing,
                "tools": [
                    {"name": t.name, "description": t.description}
                    for t in tools.values()
                ],
            },
        )

        # Validate we got a non-empty list; otherwise degrade to a
        # trivial one-step plan rather than failing the request.
        if not isinstance(plan_steps, list) or not plan_steps:
            logger.warning("plan-create returned invalid result, falling back to single step")
            plan_steps = [{"goal": "Answer the question directly", "tool_hint": "", "depends_on": []}]

        # Emit thought about the plan
        thought_text = f"Created plan with {len(plan_steps)} steps"
        await think(thought_text, is_final=True)

        # Build PlanStep objects — every step starts out pending
        plan_agent_steps = [
            PlanStep(
                goal=ps.get("goal", ""),
                tool_hint=ps.get("tool_hint", ""),
                depends_on=ps.get("depends_on", []),
                status="pending",
                result="",
            )
            for ps in plan_steps
        ]

        # Create a plan step in history; step_type="plan" is what
        # _extract_plan keys off in later iterations
        plan_history_step = AgentStep(
            thought=thought_text,
            action="plan",
            arguments={},
            observation=json.dumps(plan_steps),
            step_type="plan",
            plan=plan_agent_steps,
        )

        # Build next request with plan in history, preserving all
        # orchestration fields from the incoming request
        new_history = list(request.history) + [plan_history_step]
        r = AgentRequest(
            question=request.question,
            state=request.state,
            group=getattr(request, 'group', []),
            history=new_history,
            user=request.user,
            collection=collection,
            streaming=streaming,
            session_id=session_id,
            conversation_id=getattr(request, 'conversation_id', ''),
            pattern=getattr(request, 'pattern', ''),
            task_type=getattr(request, 'task_type', ''),
            framing=getattr(request, 'framing', ''),
            correlation_id=getattr(request, 'correlation_id', ''),
            parent_session_id=getattr(request, 'parent_session_id', ''),
            subagent_goal=getattr(request, 'subagent_goal', ''),
            expected_siblings=getattr(request, 'expected_siblings', 0),
        )
        await next(r)
|
||||
|
||||
async def _execution_iteration(self, request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num, plan):
|
||||
"""Execute the next pending plan step via single-shot tool call."""
|
||||
|
||||
pending_idx = self._find_next_pending_step(plan)
|
||||
|
||||
if pending_idx is None:
|
||||
# All steps done — synthesise final answer
|
||||
await self._synthesise(
|
||||
request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num, plan,
|
||||
)
|
||||
return
|
||||
|
||||
current_step = plan[pending_idx]
|
||||
goal = getattr(current_step, 'goal', '') or str(current_step)
|
||||
|
||||
logger.info(f"Executing plan step {pending_idx}: {goal}")
|
||||
|
||||
think = self.make_think_callback(respond, streaming)
|
||||
observe = self.make_observe_callback(respond, streaming)
|
||||
|
||||
# Gather results from dependencies
|
||||
previous_results = []
|
||||
depends_on = getattr(current_step, 'depends_on', [])
|
||||
if depends_on:
|
||||
for dep_idx in depends_on:
|
||||
if 0 <= dep_idx < len(plan):
|
||||
dep_step = plan[dep_idx]
|
||||
dep_result = getattr(dep_step, 'result', '')
|
||||
if dep_result:
|
||||
previous_results.append({
|
||||
"index": dep_idx,
|
||||
"result": dep_result,
|
||||
})
|
||||
|
||||
tools = self.filter_tools(self.processor.agent.tools, request)
|
||||
context = self.make_context(flow, request.user)
|
||||
client = context("prompt-request")
|
||||
|
||||
# Single-shot: ask LLM which tool + arguments to use for this goal
|
||||
tool_call = await client.prompt(
|
||||
id="plan-step-execute",
|
||||
variables={
|
||||
"goal": goal,
|
||||
"previous_results": previous_results,
|
||||
"tools": [
|
||||
{
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"arguments": [
|
||||
{"name": a.name, "type": a.type, "description": a.description}
|
||||
for a in t.arguments
|
||||
],
|
||||
}
|
||||
for t in tools.values()
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
tool_name = tool_call.get("tool", "")
|
||||
tool_arguments = tool_call.get("arguments", {})
|
||||
|
||||
await think(
|
||||
f"Step {pending_idx}: {goal} → calling {tool_name}",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Invoke the tool directly
|
||||
if tool_name in tools:
|
||||
tool = tools[tool_name]
|
||||
resp = await tool.implementation(context).invoke(**tool_arguments)
|
||||
step_result = resp.strip() if isinstance(resp, str) else str(resp).strip()
|
||||
else:
|
||||
logger.warning(
|
||||
f"Plan step {pending_idx}: LLM selected unknown tool "
|
||||
f"'{tool_name}', available: {list(tools.keys())}"
|
||||
)
|
||||
step_result = f"Error: tool '{tool_name}' not found"
|
||||
|
||||
await observe(step_result, is_final=True)
|
||||
|
||||
# Update plan step status
|
||||
plan[pending_idx] = PlanStep(
|
||||
goal=goal,
|
||||
tool_hint=getattr(current_step, 'tool_hint', ''),
|
||||
depends_on=getattr(current_step, 'depends_on', []),
|
||||
status="completed",
|
||||
result=step_result,
|
||||
)
|
||||
|
||||
# Emit iteration provenance
|
||||
prov_act = Action(
|
||||
thought=f"Plan step {pending_idx}: {goal}",
|
||||
name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
observation=step_result,
|
||||
)
|
||||
await self.emit_iteration_triples(
|
||||
flow, session_id, iteration_num, session_uri,
|
||||
prov_act, request, respond, streaming,
|
||||
)
|
||||
|
||||
# Build execution step for history
|
||||
exec_step = AgentStep(
|
||||
thought=f"Executing plan step {pending_idx}: {goal}",
|
||||
action=tool_name,
|
||||
arguments={k: str(v) for k, v in tool_arguments.items()},
|
||||
observation=step_result,
|
||||
step_type="execute",
|
||||
plan=plan,
|
||||
)
|
||||
|
||||
new_history = list(request.history) + [exec_step]
|
||||
|
||||
r = AgentRequest(
|
||||
question=request.question,
|
||||
state=request.state,
|
||||
group=getattr(request, 'group', []),
|
||||
history=new_history,
|
||||
user=request.user,
|
||||
collection=collection,
|
||||
streaming=streaming,
|
||||
session_id=session_id,
|
||||
conversation_id=getattr(request, 'conversation_id', ''),
|
||||
pattern=getattr(request, 'pattern', ''),
|
||||
task_type=getattr(request, 'task_type', ''),
|
||||
framing=getattr(request, 'framing', ''),
|
||||
correlation_id=getattr(request, 'correlation_id', ''),
|
||||
parent_session_id=getattr(request, 'parent_session_id', ''),
|
||||
subagent_goal=getattr(request, 'subagent_goal', ''),
|
||||
expected_siblings=getattr(request, 'expected_siblings', 0),
|
||||
)
|
||||
await next(r)
|
||||
|
||||
async def _synthesise(self, request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num, plan):
|
||||
"""Synthesise a final answer from all completed plan step results."""
|
||||
|
||||
think = self.make_think_callback(respond, streaming)
|
||||
framing = getattr(request, 'framing', '')
|
||||
|
||||
context = self.make_context(flow, request.user)
|
||||
client = context("prompt-request")
|
||||
|
||||
# Use the plan-synthesise prompt template
|
||||
steps_data = []
|
||||
for i, step in enumerate(plan):
|
||||
steps_data.append({
|
||||
"index": i,
|
||||
"goal": getattr(step, 'goal', f'Step {i}'),
|
||||
"result": getattr(step, 'result', ''),
|
||||
})
|
||||
|
||||
await think("Synthesising final answer from plan results", is_final=True)
|
||||
|
||||
response_text = await self.prompt_as_answer(
|
||||
client, "plan-synthesise",
|
||||
variables={
|
||||
"question": request.question,
|
||||
"framing": framing,
|
||||
"steps": steps_data,
|
||||
},
|
||||
respond=respond,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
await self.emit_final_triples(
|
||||
flow, session_id, iteration_num, session_uri,
|
||||
response_text, request, respond, streaming,
|
||||
)
|
||||
await self.send_final_response(
|
||||
respond, streaming, response_text, already_streamed=streaming,
|
||||
)
|
||||
134
trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py
Normal file
134
trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
ReactPattern — extracted from the existing agent_manager.py.
|
||||
|
||||
Implements the ReACT (Reasoning + Acting) loop: think, select a tool,
|
||||
observe the result, repeat until a final answer is produced.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep
|
||||
|
||||
from ..react.agent_manager import AgentManager
|
||||
from ..react.types import Action, Final
|
||||
from ..tool_filter import get_next_state
|
||||
|
||||
from . pattern_base import PatternBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReactPattern(PatternBase):
    """
    ReACT pattern: interleaved reasoning and action.

    Each iterate() call performs exactly one reason/act cycle.  When
    the LLM produces a Final answer the dialog completes; otherwise the
    action and its observation are appended to history and a follow-up
    request is published for the next cycle.
    """

    async def iterate(self, request, respond, next, flow):
        """Run one ReACT cycle for *request*, responding or re-queueing."""

        streaming = getattr(request, 'streaming', False)
        session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
        collection = getattr(request, 'collection', 'default')

        history = self.build_history(request)
        iteration_num = len(history) + 1
        session_uri = self.processor.provenance_session_uri(session_id)

        # First cycle: record the session itself in provenance
        if iteration_num == 1:
            await self.emit_session_triples(
                flow, session_uri, request.question,
                request.user, collection, respond, streaming,
            )

        logger.info(f"ReactPattern iteration {iteration_num}: {request.question}")

        if len(history) >= self.processor.max_iterations:
            raise RuntimeError("Too many agent iterations")

        # Streaming callbacks for thoughts / observations / answers
        think = self.make_think_callback(respond, streaming)
        observe = self.make_observe_callback(respond, streaming)
        answer_cb = self.make_answer_callback(respond, streaming)

        # Restrict the tool set according to the request's state/group
        filtered_tools = self.filter_tools(
            self.processor.agent.tools, request,
        )
        logger.info(
            f"Filtered from {len(self.processor.agent.tools)} "
            f"to {len(filtered_tools)} available tools"
        )

        # Fold any router-supplied framing into the agent context
        additional_context = self.processor.agent.additional_context
        framing = getattr(request, 'framing', '')
        if framing:
            additional_context = (
                f"{additional_context}\n\n{framing}"
                if additional_context else framing
            )

        temp_agent = AgentManager(
            tools=filtered_tools,
            additional_context=additional_context,
        )

        context = self.make_context(flow, request.user)

        act = await temp_agent.react(
            question=request.question,
            history=history,
            think=think,
            observe=observe,
            answer=answer_cb,
            context=context,
            streaming=streaming,
        )

        logger.debug(f"Action: {act}")

        if isinstance(act, Final):

            final_text = (
                act.final if isinstance(act.final, str)
                else json.dumps(act.final)
            )

            # Emit final provenance, then close out the dialog
            await self.emit_final_triples(
                flow, session_id, iteration_num, session_uri,
                final_text, request, respond, streaming,
            )
            await self.send_final_response(
                respond, streaming, final_text, already_streamed=streaming,
            )
            return

        # Not final — record this iteration and queue the next one
        await self.emit_iteration_triples(
            flow, session_id, iteration_num, session_uri,
            act, request, respond, streaming,
        )

        history.append(act)

        # Executing a tool may move the dialog to a new state
        next_state = request.state
        if act.name in filtered_tools:
            executed_tool = filtered_tools[act.name]
            next_state = get_next_state(executed_tool, request.state or "undefined")

        r = self.build_next_request(
            request, history, session_id, collection,
            streaming, next_state,
        )
        await next(r)

        logger.debug("ReactPattern iteration complete")
||||
511
trustgraph-flow/trustgraph/agent/orchestrator/service.py
Normal file
511
trustgraph-flow/trustgraph/agent/orchestrator/service.py
Normal file
|
|
@ -0,0 +1,511 @@
|
|||
"""
|
||||
Agent orchestrator service — multi-pattern drop-in replacement for
|
||||
agent-manager-react.
|
||||
|
||||
Uses the same service identity and Pulsar queues. Adds meta-routing
|
||||
to select between ReactPattern, PlanThenExecutePattern, and
|
||||
SupervisorPattern at runtime.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import functools
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
||||
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
|
||||
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
|
||||
from ... base import ProducerSpec
|
||||
from ... base import Consumer, Producer
|
||||
from ... base import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from ... schema import Triples, Metadata
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
|
||||
from trustgraph.provenance import (
|
||||
agent_session_uri,
|
||||
GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
from ..react.tools import (
|
||||
KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl,
|
||||
StructuredQueryImpl, RowEmbeddingsQueryImpl, ToolServiceImpl,
|
||||
)
|
||||
from ..react.agent_manager import AgentManager
|
||||
from ..tool_filter import validate_tool_config
|
||||
from ..react.types import Final, Action, Tool, Argument
|
||||
|
||||
from . meta_router import MetaRouter
|
||||
from . pattern_base import PatternBase, UserAwareContext
|
||||
from . react_pattern import ReactPattern
|
||||
from . plan_pattern import PlanThenExecutePattern
|
||||
from . supervisor_pattern import SupervisorPattern
|
||||
from . aggregator import Aggregator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "agent-manager"
|
||||
default_max_iterations = 10
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
|
||||
class Processor(AgentService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
|
||||
self.max_iterations = int(
|
||||
params.get("max_iterations", default_max_iterations)
|
||||
)
|
||||
|
||||
self.config_key = params.get("config_type", "agent")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"max_iterations": self.max_iterations,
|
||||
"config_type": self.config_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.agent = AgentManager(
|
||||
tools={},
|
||||
additional_context="",
|
||||
)
|
||||
|
||||
self.tool_service_clients = {}
|
||||
|
||||
# Patterns
|
||||
self.react_pattern = ReactPattern(self)
|
||||
self.plan_pattern = PlanThenExecutePattern(self)
|
||||
self.supervisor_pattern = SupervisorPattern(self)
|
||||
|
||||
# Aggregator for supervisor fan-in
|
||||
self.aggregator = Aggregator()
|
||||
|
||||
# Meta-router (initialised on first config load)
|
||||
self.meta_router = None
|
||||
|
||||
self.config_handlers.append(self.on_tools_config)
|
||||
|
||||
self.register_specification(
|
||||
TextCompletionClientSpec(
|
||||
request_name="text-completion-request",
|
||||
response_name="text-completion-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
GraphRagClientSpec(
|
||||
request_name="graph-rag-request",
|
||||
response_name="graph-rag-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
PromptClientSpec(
|
||||
request_name="prompt-request",
|
||||
response_name="prompt-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ToolClientSpec(
|
||||
request_name="mcp-tool-request",
|
||||
response_name="mcp-tool-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
StructuredQueryClientSpec(
|
||||
request_name="structured-query-request",
|
||||
response_name="structured-query-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
EmbeddingsClientSpec(
|
||||
request_name="embeddings-request",
|
||||
response_name="embeddings-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RowEmbeddingsQueryClientSpec(
|
||||
request_name="row-embeddings-query-request",
|
||||
response_name="row-embeddings-query-response",
|
||||
)
|
||||
)
|
||||
|
||||
# Explainability producer
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name="explainability",
|
||||
schema=Triples,
|
||||
)
|
||||
)
|
||||
|
||||
# Librarian client
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend=self.pubsub,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
backend=self.pubsub,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self.on_librarian_response,
|
||||
metrics=librarian_response_metrics,
|
||||
)
|
||||
|
||||
self.pending_librarian_requests = {}
|
||||
|
||||
async def start(self):
|
||||
await super(Processor, self).start()
|
||||
await self.librarian_request_producer.start()
|
||||
await self.librarian_response_consumer.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if request_id in self.pending_librarian_requests:
|
||||
future = self.pending_librarian_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
|
||||
async def save_answer_content(self, doc_id, user, content, title=None,
|
||||
timeout=120):
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind="text/plain",
|
||||
title=title or "Agent Answer",
|
||||
document_type="answer",
|
||||
)
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="add-document",
|
||||
document_id=doc_id,
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
|
||||
user=user,
|
||||
)
|
||||
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_librarian_requests[request_id] = future
|
||||
|
||||
try:
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error saving answer: "
|
||||
f"{response.error.type}: {response.error.message}"
|
||||
)
|
||||
return doc_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_librarian_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout saving answer document {doc_id}")
|
||||
|
||||
    def provenance_session_uri(self, session_id):
        """Map an agent session id onto its provenance session URI."""
        return agent_session_uri(session_id)
||||
|
||||
    async def on_tools_config(self, config, version):
        """Rebuild the tool set, agent, and meta-router from new config.

        Invoked as a config handler whenever configuration version
        *version* arrives.  On any failure the whole reload is abandoned
        and the previous agent/tools remain in effect.
        """

        logger.info(f"Loading configuration version {version}")

        try:
            tools = {}

            # Load tool-service configurations first: "tool-service"
            # entries are referenced by name from tools of type
            # 'tool-service' below.
            tool_services = {}
            if "tool-service" in config:
                for service_id, service_value in config["tool-service"].items():
                    service_data = json.loads(service_value)
                    tool_services[service_id] = service_data
                    logger.debug(f"Loaded tool-service config: {service_id}")

            logger.info(
                f"Loaded {len(tool_services)} tool-service configurations"
            )

            # Load tool configurations.  Each entry maps its "type" field
            # to an implementation class (wrapped in functools.partial to
            # pre-bind per-tool settings).
            if "tool" in config:
                for tool_id, tool_value in config["tool"].items():
                    data = json.loads(tool_value)
                    impl_id = data.get("type")
                    name = data.get("name")

                    if impl_id == "knowledge-query":
                        impl = functools.partial(
                            KnowledgeQueryImpl,
                            collection=data.get("collection"),
                        )
                        arguments = KnowledgeQueryImpl.get_arguments()
                    elif impl_id == "text-completion":
                        impl = TextCompletionImpl
                        arguments = TextCompletionImpl.get_arguments()
                    elif impl_id == "mcp-tool":
                        # Arguments are declared in config for MCP tools
                        config_args = data.get("arguments", [])
                        arguments = [
                            Argument(
                                name=arg.get("name"),
                                type=arg.get("type"),
                                description=arg.get("description"),
                            )
                            for arg in config_args
                        ]
                        impl = functools.partial(
                            McpToolImpl,
                            mcp_tool_id=data.get("mcp-tool"),
                            arguments=arguments,
                        )
                    elif impl_id == "prompt":
                        # Arguments are declared in config for prompt tools
                        config_args = data.get("arguments", [])
                        arguments = [
                            Argument(
                                name=arg.get("name"),
                                type=arg.get("type"),
                                description=arg.get("description"),
                            )
                            for arg in config_args
                        ]
                        impl = functools.partial(
                            PromptImpl,
                            template_id=data.get("template"),
                            arguments=arguments,
                        )
                    elif impl_id == "structured-query":
                        # user=None: resolved per-request, not at config time
                        impl = functools.partial(
                            StructuredQueryImpl,
                            collection=data.get("collection"),
                            user=None,
                        )
                        arguments = StructuredQueryImpl.get_arguments()
                    elif impl_id == "row-embeddings-query":
                        impl = functools.partial(
                            RowEmbeddingsQueryImpl,
                            schema_name=data.get("schema-name"),
                            collection=data.get("collection"),
                            user=None,
                            index_name=data.get("index-name"),
                            limit=int(data.get("limit", 10)),
                        )
                        arguments = RowEmbeddingsQueryImpl.get_arguments()
                    elif impl_id == "tool-service":
                        # Tool backed by an external service: resolve the
                        # named tool-service and validate its queue config.
                        service_ref = data.get("service")
                        if not service_ref:
                            raise RuntimeError(
                                f"Tool {name} has type 'tool-service' "
                                f"but no 'service' reference"
                            )
                        if service_ref not in tool_services:
                            raise RuntimeError(
                                f"Tool {name} references unknown "
                                f"tool-service '{service_ref}'"
                            )

                        service_config = tool_services[service_ref]
                        request_queue = service_config.get("request-queue")
                        response_queue = service_config.get("response-queue")
                        if not request_queue or not response_queue:
                            raise RuntimeError(
                                f"Tool-service '{service_ref}' must define "
                                f"'request-queue' and 'response-queue'"
                            )

                        # Collect config-param values declared by the
                        # service from the tool's own config entry.
                        # Params may be plain names or {"name", "required"}
                        # dicts; only dict-form params can be required.
                        config_params = service_config.get("config-params", [])
                        config_values = {}
                        for param in config_params:
                            param_name = (
                                param.get("name")
                                if isinstance(param, dict) else param
                            )
                            if param_name in data:
                                config_values[param_name] = data[param_name]
                            elif (
                                isinstance(param, dict)
                                and param.get("required", False)
                            ):
                                raise RuntimeError(
                                    f"Tool {name} missing required config "
                                    f"param '{param_name}'"
                                )

                        config_args = data.get("arguments", [])
                        arguments = [
                            Argument(
                                name=arg.get("name"),
                                type=arg.get("type"),
                                description=arg.get("description"),
                            )
                            for arg in config_args
                        ]

                        impl = functools.partial(
                            ToolServiceImpl,
                            request_queue=request_queue,
                            response_queue=response_queue,
                            config_values=config_values,
                            arguments=arguments,
                            processor=self,
                        )
                    else:
                        raise RuntimeError(
                            f"Tool type {impl_id} not known"
                        )

                    validate_tool_config(data)

                    tools[name] = Tool(
                        name=name,
                        description=data.get("description"),
                        implementation=impl,
                        config=data,
                        arguments=arguments,
                    )

            # Load additional context from agent config
            # NOTE(review): assumes config[self.config_key] is a mapping
            # with a plain-string "additional-context" value — confirm
            # against the config publisher.
            additional = None
            if self.config_key in config:
                agent_config = config[self.config_key]
                additional = agent_config.get("additional-context", None)

            # Swap in the freshly built agent atomically
            self.agent = AgentManager(
                tools=tools,
                additional_context=additional,
            )

            # Re-initialise meta-router with config
            self.meta_router = MetaRouter(config=config)

            logger.info(f"Loaded {len(tools)} tools")
            logger.info("Tool configuration reloaded.")

        except Exception as e:
            # Keep running on the previous configuration
            logger.error(
                f"on_tools_config Exception: {e}", exc_info=True
            )
            logger.error("Configuration reload failed")
||||
|
||||
async def agent_request(self, request, respond, next, flow):
|
||||
|
||||
try:
|
||||
pattern = getattr(request, 'pattern', '') or ''
|
||||
|
||||
# If no pattern set and this is the first iteration, route
|
||||
if not pattern and not request.history:
|
||||
context = UserAwareContext(flow, request.user)
|
||||
|
||||
if self.meta_router:
|
||||
pattern, task_type, framing = await self.meta_router.route(
|
||||
request.question, context,
|
||||
)
|
||||
else:
|
||||
pattern = "react"
|
||||
task_type = "general"
|
||||
framing = ""
|
||||
|
||||
# Update request with routing decision
|
||||
request.pattern = pattern
|
||||
request.task_type = task_type
|
||||
request.framing = framing
|
||||
|
||||
logger.info(
|
||||
f"Routed to pattern={pattern}, "
|
||||
f"task_type={task_type}"
|
||||
)
|
||||
|
||||
# Dispatch to the selected pattern
|
||||
if pattern == "plan-then-execute":
|
||||
await self.plan_pattern.iterate(
|
||||
request, respond, next, flow,
|
||||
)
|
||||
elif pattern == "supervisor":
|
||||
await self.supervisor_pattern.iterate(
|
||||
request, respond, next, flow,
|
||||
)
|
||||
else:
|
||||
# Default to react
|
||||
await self.react_pattern.iterate(
|
||||
request, respond, next, flow,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logger.error(
|
||||
f"agent_request Exception: {e}", exc_info=True
|
||||
)
|
||||
|
||||
logger.debug("Send error response...")
|
||||
|
||||
error_obj = Error(
|
||||
type="agent-error",
|
||||
message=str(e),
|
||||
)
|
||||
|
||||
r = AgentResponse(
|
||||
chunk_type="error",
|
||||
content=str(e),
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
error=error_obj,
|
||||
)
|
||||
|
||||
await respond(r)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
AgentService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--max-iterations',
|
||||
default=default_max_iterations,
|
||||
help=f'Maximum number of react iterations '
|
||||
f'(default: {default_max_iterations})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--config-type',
|
||||
default="agent",
|
||||
help='Configuration key for prompts (default: agent)',
|
||||
)
|
||||
|
||||
|
||||
def run():
    """Console-script entry point: launch the orchestrator processor."""
    Processor.launch(default_ident, __doc__)
||||
|
|
@ -0,0 +1,214 @@
|
|||
"""
|
||||
SupervisorPattern — decomposes a query into subagent goals, fans out,
|
||||
then synthesises results when all subagents complete.
|
||||
|
||||
Phase 1 (decompose): LLM breaks the query into independent sub-goals.
|
||||
Fan-out: Each sub-goal is emitted as a new AgentRequest on the agent
|
||||
request topic, carrying a correlation_id and parent_session_id.
|
||||
Phase 2 (synthesise): Triggered when the aggregator detects all
|
||||
subagents have completed. The supervisor fetches results and
|
||||
produces the final answer.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep
|
||||
|
||||
from ..react.types import Action, Final
|
||||
|
||||
from . pattern_base import PatternBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_SUBAGENTS = 5
|
||||
|
||||
|
||||
class SupervisorPattern(PatternBase):
|
||||
"""
|
||||
Supervisor pattern: decompose, fan-out, synthesise.
|
||||
|
||||
History tracks phase via AgentStep.step_type:
|
||||
- "decompose": the decomposition step (subagent goals in arguments)
|
||||
- "synthesise": triggered by aggregator with results in subagent_results
|
||||
"""
|
||||
|
||||
async def iterate(self, request, respond, next, flow):
|
||||
|
||||
streaming = getattr(request, 'streaming', False)
|
||||
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
|
||||
collection = getattr(request, 'collection', 'default')
|
||||
iteration_num = len(request.history) + 1
|
||||
session_uri = self.processor.provenance_session_uri(session_id)
|
||||
|
||||
# Emit session provenance on first iteration
|
||||
if iteration_num == 1:
|
||||
await self.emit_session_triples(
|
||||
flow, session_uri, request.question,
|
||||
request.user, collection, respond, streaming,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"SupervisorPattern iteration {iteration_num}: {request.question}"
|
||||
)
|
||||
|
||||
# Check if this is a synthesis request (has subagent_results)
|
||||
has_results = bool(
|
||||
request.history
|
||||
and any(
|
||||
getattr(h, 'step_type', '') == 'decompose'
|
||||
for h in request.history
|
||||
)
|
||||
and any(
|
||||
getattr(h, 'subagent_results', None)
|
||||
for h in request.history
|
||||
)
|
||||
)
|
||||
|
||||
if has_results:
|
||||
await self._synthesise(
|
||||
request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num,
|
||||
)
|
||||
else:
|
||||
await self._decompose_and_fanout(
|
||||
request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num,
|
||||
)
|
||||
|
||||
async def _decompose_and_fanout(self, request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num):
|
||||
"""Decompose the question into sub-goals and fan out subagents."""
|
||||
|
||||
think = self.make_think_callback(respond, streaming)
|
||||
framing = getattr(request, 'framing', '')
|
||||
|
||||
tools = self.filter_tools(self.processor.agent.tools, request)
|
||||
|
||||
context = self.make_context(flow, request.user)
|
||||
client = context("prompt-request")
|
||||
|
||||
# Use the supervisor-decompose prompt template
|
||||
goals = await client.prompt(
|
||||
id="supervisor-decompose",
|
||||
variables={
|
||||
"question": request.question,
|
||||
"framing": framing,
|
||||
"max_subagents": MAX_SUBAGENTS,
|
||||
"tools": [
|
||||
{"name": t.name, "description": t.description}
|
||||
for t in tools.values()
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# Validate result
|
||||
if not isinstance(goals, list):
|
||||
goals = []
|
||||
goals = [g for g in goals if isinstance(g, str)]
|
||||
goals = goals[:MAX_SUBAGENTS]
|
||||
|
||||
if not goals:
|
||||
goals = [request.question]
|
||||
|
||||
await think(
|
||||
f"Decomposed into {len(goals)} sub-goals: {goals}",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Generate correlation ID for this fan-out
|
||||
correlation_id = str(uuid.uuid4())
|
||||
|
||||
# Emit decomposition provenance
|
||||
decompose_act = Action(
|
||||
thought=f"Decomposed into {len(goals)} sub-goals",
|
||||
name="decompose",
|
||||
arguments={"goals": json.dumps(goals), "correlation_id": correlation_id},
|
||||
observation=f"Fanning out {len(goals)} subagents",
|
||||
)
|
||||
await self.emit_iteration_triples(
|
||||
flow, session_id, iteration_num, session_uri,
|
||||
decompose_act, request, respond, streaming,
|
||||
)
|
||||
|
||||
# Fan out: emit a subagent request for each goal
|
||||
for i, goal in enumerate(goals):
|
||||
subagent_session = str(uuid.uuid4())
|
||||
sub_request = AgentRequest(
|
||||
question=goal,
|
||||
state="",
|
||||
group=getattr(request, 'group', []),
|
||||
history=[],
|
||||
user=request.user,
|
||||
collection=collection,
|
||||
streaming=False, # Subagents don't stream
|
||||
session_id=subagent_session,
|
||||
conversation_id=getattr(request, 'conversation_id', ''),
|
||||
pattern="react", # Subagents use react by default
|
||||
task_type=getattr(request, 'task_type', ''),
|
||||
framing=getattr(request, 'framing', ''),
|
||||
correlation_id=correlation_id,
|
||||
parent_session_id=session_id,
|
||||
subagent_goal=goal,
|
||||
expected_siblings=len(goals),
|
||||
)
|
||||
await next(sub_request)
|
||||
logger.info(f"Fan-out: emitted subagent {i} for goal: {goal}")
|
||||
|
||||
# NOTE: The supervisor stops here. The aggregator will detect
|
||||
# when all subagents complete and emit a synthesis request
|
||||
# with the results populated.
|
||||
logger.info(
|
||||
f"Supervisor fan-out complete: {len(goals)} subagents, "
|
||||
f"correlation_id={correlation_id}"
|
||||
)
|
||||
|
||||
async def _synthesise(self, request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num):
|
||||
"""Synthesise final answer from subagent results."""
|
||||
|
||||
think = self.make_think_callback(respond, streaming)
|
||||
framing = getattr(request, 'framing', '')
|
||||
|
||||
# Collect subagent results from history
|
||||
subagent_results = {}
|
||||
for step in request.history:
|
||||
results = getattr(step, 'subagent_results', None)
|
||||
if results:
|
||||
subagent_results.update(results)
|
||||
|
||||
if not subagent_results:
|
||||
logger.warning("Synthesis called with no subagent results")
|
||||
subagent_results = {"(no results)": "No subagent results available"}
|
||||
|
||||
context = self.make_context(flow, request.user)
|
||||
client = context("prompt-request")
|
||||
|
||||
await think("Synthesising final answer from sub-agent results", is_final=True)
|
||||
|
||||
response_text = await self.prompt_as_answer(
|
||||
client, "supervisor-synthesise",
|
||||
variables={
|
||||
"question": request.question,
|
||||
"framing": framing,
|
||||
"results": [
|
||||
{"goal": goal, "result": result}
|
||||
for goal, result in subagent_results.items()
|
||||
],
|
||||
},
|
||||
respond=respond,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
await self.emit_final_triples(
|
||||
flow, session_id, iteration_num, session_uri,
|
||||
response_text, request, respond, streaming,
|
||||
)
|
||||
await self.send_final_response(
|
||||
respond, streaming, response_text, already_streamed=streaming,
|
||||
)
|
||||
|
|
@@ -485,25 +485,16 @@ class Processor(AgentService):
|
|||
logger.debug(f"Think: {x} (is_final={is_final})")
|
||||
|
||||
if streaming:
|
||||
# Streaming format
|
||||
r = AgentResponse(
|
||||
chunk_type="thought",
|
||||
content=x,
|
||||
end_of_message=is_final,
|
||||
end_of_dialog=False,
|
||||
# Legacy fields for backward compatibility
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=x,
|
||||
observation=None,
|
||||
)
|
||||
else:
|
||||
# Non-streaming format
|
||||
r = AgentResponse(
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=x,
|
||||
observation=None,
|
||||
chunk_type="thought",
|
||||
content=x,
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
|
@@ -515,25 +506,16 @@ class Processor(AgentService):
|
|||
logger.debug(f"Observe: {x} (is_final={is_final})")
|
||||
|
||||
if streaming:
|
||||
# Streaming format
|
||||
r = AgentResponse(
|
||||
chunk_type="observation",
|
||||
content=x,
|
||||
end_of_message=is_final,
|
||||
end_of_dialog=False,
|
||||
# Legacy fields for backward compatibility
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=x,
|
||||
)
|
||||
else:
|
||||
# Non-streaming format
|
||||
r = AgentResponse(
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=x,
|
||||
chunk_type="observation",
|
||||
content=x,
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
|
@@ -545,25 +527,16 @@ class Processor(AgentService):
|
|||
logger.debug(f"Answer: {x}")
|
||||
|
||||
if streaming:
|
||||
# Streaming format
|
||||
r = AgentResponse(
|
||||
chunk_type="answer",
|
||||
content=x,
|
||||
end_of_message=False, # More chunks may follow
|
||||
end_of_message=False,
|
||||
end_of_dialog=False,
|
||||
# Legacy fields for backward compatibility
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None,
|
||||
)
|
||||
else:
|
||||
# Non-streaming format - shouldn't normally be called
|
||||
r = AgentResponse(
|
||||
answer=x,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None,
|
||||
chunk_type="answer",
|
||||
content=x,
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
|
@@ -677,25 +650,17 @@ class Processor(AgentService):
|
|||
))
|
||||
|
||||
if streaming:
|
||||
# Streaming format - send end-of-dialog marker
|
||||
# Answer chunks were already sent via answer() callback during parsing
|
||||
# End-of-dialog marker — answer chunks already sent via callback
|
||||
r = AgentResponse(
|
||||
chunk_type="answer",
|
||||
content="", # Empty content, just marking end of dialog
|
||||
content="",
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
# Legacy fields set to None - answer already sent via streaming chunks
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=None,
|
||||
)
|
||||
else:
|
||||
# Non-streaming format - send complete answer
|
||||
r = AgentResponse(
|
||||
answer=act.final,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None,
|
||||
chunk_type="answer",
|
||||
content=f,
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
)
|
||||
|
|
@@ -833,21 +798,13 @@ class Processor(AgentService):
|
|||
# Check if streaming was enabled (may not be set if error occurred early)
|
||||
streaming = getattr(request, 'streaming', False) if 'request' in locals() else False
|
||||
|
||||
if streaming:
|
||||
# Streaming format
|
||||
r = AgentResponse(
|
||||
chunk_type="error",
|
||||
content=str(e),
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
# Legacy fields for backward compatibility
|
||||
error=error_obj,
|
||||
)
|
||||
else:
|
||||
# Legacy format
|
||||
r = AgentResponse(
|
||||
error=error_obj,
|
||||
)
|
||||
r = AgentResponse(
|
||||
chunk_type="error",
|
||||
content=str(e),
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
error=error_obj,
|
||||
)
|
||||
|
||||
await respond(r)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue