mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Add agent explainability instrumentation and unify envelope field naming (#795)
Addresses recommendations from the UX developer's agent experience report. Adds provenance predicates, DAG structure changes, error resilience, and a published OWL ontology. Explainability additions: - Tool candidates: tg:toolCandidate on Analysis events lists the tools visible to the LLM for each iteration (names only, descriptions in config) - Termination reason: tg:terminationReason on Conclusion/Synthesis events (final-answer, plan-complete, subagents-complete) - Step counter: tg:stepNumber on iteration events - Pattern decision: new tg:PatternDecision entity in the DAG between session and first iteration, carrying tg:pattern and tg:taskType - Latency: tg:llmDurationMs on Analysis events, tg:toolDurationMs on Observation events - Token counts on events: tg:inToken/tg:outToken/tg:llmModel on Grounding, Focus, Synthesis, and Analysis events - Tool/parse errors: tg:toolError on Observation events with tg:Error mixin type. Parse failures return as error observations instead of crashing the agent, giving it a chance to retry. Envelope unification: - Rename chunk_type to message_type across AgentResponse schema, translator, SDK types, socket clients, CLI, and all tests. Agent and RAG services now both use message_type on the wire. Ontology: - specs/ontology/trustgraph.ttl — OWL vocabulary covering all 26 classes, 7 object properties, and 36+ datatype properties including new predicates. DAG structure tests: - tests/unit/test_provenance/test_dag_structure.py verifies the wasDerivedFrom chain for GraphRAG, DocumentRAG, and all three agent patterns (react, plan, supervisor) including the pattern-decision link.
This commit is contained in:
parent
14e49d83c7
commit
d2751553a3
42 changed files with 1577 additions and 205 deletions
|
|
@ -78,10 +78,10 @@ class TestAgentServiceNonStreaming:
|
|||
|
||||
# Filter out explain events — those are always sent now
|
||||
content_responses = [
|
||||
r for r in sent_responses if r.chunk_type != "explain"
|
||||
r for r in sent_responses if r.message_type != "explain"
|
||||
]
|
||||
explain_responses = [
|
||||
r for r in sent_responses if r.chunk_type == "explain"
|
||||
r for r in sent_responses if r.message_type == "explain"
|
||||
]
|
||||
|
||||
# Should have explain events for session, iteration, observation, and final
|
||||
|
|
@ -93,7 +93,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Check thought message
|
||||
thought_response = content_responses[0]
|
||||
assert isinstance(thought_response, AgentResponse)
|
||||
assert thought_response.chunk_type == "thought"
|
||||
assert thought_response.message_type == "thought"
|
||||
assert thought_response.content == "I need to solve this."
|
||||
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
|
||||
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
|
||||
|
|
@ -101,7 +101,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Check observation message
|
||||
observation_response = content_responses[1]
|
||||
assert isinstance(observation_response, AgentResponse)
|
||||
assert observation_response.chunk_type == "observation"
|
||||
assert observation_response.message_type == "observation"
|
||||
assert observation_response.content == "The answer is 4."
|
||||
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
|
||||
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
|
||||
|
|
@ -168,10 +168,10 @@ class TestAgentServiceNonStreaming:
|
|||
|
||||
# Filter out explain events — those are always sent now
|
||||
content_responses = [
|
||||
r for r in sent_responses if r.chunk_type != "explain"
|
||||
r for r in sent_responses if r.message_type != "explain"
|
||||
]
|
||||
explain_responses = [
|
||||
r for r in sent_responses if r.chunk_type == "explain"
|
||||
r for r in sent_responses if r.message_type == "explain"
|
||||
]
|
||||
|
||||
# Should have explain events for session and final
|
||||
|
|
@ -183,7 +183,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Check final answer message
|
||||
answer_response = content_responses[0]
|
||||
assert isinstance(answer_response, AgentResponse)
|
||||
assert answer_response.chunk_type == "answer"
|
||||
assert answer_response.message_type == "answer"
|
||||
assert answer_response.content == "4"
|
||||
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
|
||||
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class TestThinkCallbackMessageId:
|
|||
|
||||
assert len(responses) == 1
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "thought"
|
||||
assert responses[0].message_type == "thought"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_think_has_message_id(self, pattern):
|
||||
|
|
@ -58,7 +58,7 @@ class TestObserveCallbackMessageId:
|
|||
await observe("result", is_final=True)
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "observation"
|
||||
assert responses[0].message_type == "observation"
|
||||
|
||||
|
||||
class TestAnswerCallbackMessageId:
|
||||
|
|
@ -74,7 +74,7 @@ class TestAnswerCallbackMessageId:
|
|||
await answer("the answer")
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "answer"
|
||||
assert responses[0].message_type == "answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_id_default(self, pattern):
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ def collect_explain_events(respond_mock):
|
|||
events = []
|
||||
for call in respond_mock.call_args_list:
|
||||
resp = call[0][0]
|
||||
if isinstance(resp, AgentResponse) and resp.chunk_type == "explain":
|
||||
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
|
||||
events.append({
|
||||
"explain_id": resp.explain_id,
|
||||
"explain_graph": resp.explain_graph,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_thought_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/thought",
|
||||
|
|
@ -31,7 +31,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_observation_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "observation",
|
||||
"message_type": "observation",
|
||||
"content": "result",
|
||||
"end_of_message": True,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/observation",
|
||||
|
|
@ -42,7 +42,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_answer_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"message_type": "answer",
|
||||
"content": "the answer",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False,
|
||||
|
|
@ -54,7 +54,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_thought_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
}
|
||||
|
|
@ -64,7 +64,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_answer_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"message_type": "answer",
|
||||
"content": "answer",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True,
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
message_type="explain",
|
||||
content="",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
|
|
@ -179,7 +179,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="thought",
|
||||
message_type="thought",
|
||||
content="I need to think...",
|
||||
)
|
||||
|
||||
|
|
@ -190,7 +190,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_triples=sample_triples(),
|
||||
end_of_dialog=False,
|
||||
|
|
@ -203,7 +203,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="answer",
|
||||
message_type="answer",
|
||||
content="The answer is...",
|
||||
end_of_dialog=True,
|
||||
)
|
||||
|
|
|
|||
592
tests/unit/test_provenance/test_dag_structure.py
Normal file
592
tests/unit/test_provenance/test_dag_structure.py
Normal file
|
|
@ -0,0 +1,592 @@
|
|||
"""
|
||||
DAG structure tests for provenance chains.
|
||||
|
||||
Verifies that the wasDerivedFrom chain has the expected shape for each
|
||||
service. These tests catch structural regressions when new entities are
|
||||
inserted into the chain (e.g. PatternDecision between session and first
|
||||
iteration).
|
||||
|
||||
Expected chains:
|
||||
|
||||
GraphRAG: question → grounding → exploration → focus → synthesis
|
||||
DocumentRAG: question → grounding → exploration → synthesis
|
||||
Agent React: session → pattern-decision → iteration → (observation → iteration)* → final
|
||||
Agent Plan: session → pattern-decision → plan → step-result(s) → synthesis
|
||||
Agent Super: session → pattern-decision → decomposition → (fan-out) → finding(s) → synthesis
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||||
Triple, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL,
|
||||
TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
|
||||
TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION,
|
||||
TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION,
|
||||
TG_OBSERVATION_TYPE,
|
||||
TG_PATTERN, TG_TASK_TYPE,
|
||||
)
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _collect_events(events):
|
||||
"""Build a dict of explain_id → {types, derived_from, triples}."""
|
||||
result = {}
|
||||
for ev in events:
|
||||
eid = ev["explain_id"]
|
||||
triples = ev["triples"]
|
||||
types = {
|
||||
t.o.iri for t in triples
|
||||
if t.s.iri == eid and t.p.iri == RDF_TYPE
|
||||
}
|
||||
parents = [
|
||||
t.o.iri for t in triples
|
||||
if t.s.iri == eid and t.p.iri == PROV_WAS_DERIVED_FROM
|
||||
]
|
||||
result[eid] = {
|
||||
"types": types,
|
||||
"derived_from": parents[0] if parents else None,
|
||||
"triples": triples,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def _find_by_type(dag, rdf_type):
|
||||
"""Find all event IDs that have the given rdf:type."""
|
||||
return [eid for eid, info in dag.items() if rdf_type in info["types"]]
|
||||
|
||||
|
||||
def _assert_chain(dag, chain_types):
|
||||
"""Assert that a linear wasDerivedFrom chain exists through the given types."""
|
||||
for i in range(1, len(chain_types)):
|
||||
parent_type = chain_types[i - 1]
|
||||
child_type = chain_types[i]
|
||||
parents = _find_by_type(dag, parent_type)
|
||||
children = _find_by_type(dag, child_type)
|
||||
assert parents, f"No entity with type {parent_type}"
|
||||
assert children, f"No entity with type {child_type}"
|
||||
# At least one child must derive from at least one parent
|
||||
linked = False
|
||||
for child_id in children:
|
||||
derived = dag[child_id]["derived_from"]
|
||||
if derived in parents:
|
||||
linked = True
|
||||
break
|
||||
assert linked, (
|
||||
f"No {child_type} derives from {parent_type}. "
|
||||
f"Children derive from: "
|
||||
f"{[dag[c]['derived_from'] for c in children]}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GraphRAG DAG structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagDagStructure:
|
||||
"""Verify: question → grounding → exploration → focus → synthesis"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clients(self):
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
graph_embeddings_client.query.return_value = [
|
||||
MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")),
|
||||
]
|
||||
triples_client.query_stream.return_value = [
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/e1"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="value"),
|
||||
)
|
||||
]
|
||||
triples_client.query.return_value = []
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="concept")
|
||||
elif template_id == "kg-edge-scoring":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "score": 10} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text="Answer.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self, mock_clients):
|
||||
rag = GraphRag(*mock_clients)
|
||||
events = []
|
||||
|
||||
async def explain_cb(triples, explain_id):
|
||||
events.append({"explain_id": explain_id, "triples": triples})
|
||||
|
||||
await rag.query(
|
||||
query="test", explain_callback=explain_cb, edge_score_limit=0,
|
||||
)
|
||||
|
||||
dag = _collect_events(events)
|
||||
assert len(dag) == 5, f"Expected 5 events, got {len(dag)}"
|
||||
|
||||
_assert_chain(dag, [
|
||||
TG_GRAPH_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_FOCUS,
|
||||
TG_SYNTHESIS,
|
||||
])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocumentRAG DAG structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentRagDagStructure:
|
||||
"""Verify: question → grounding → exploration → synthesis"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clients(self):
|
||||
from trustgraph.schema import ChunkMatch
|
||||
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
doc_embeddings_client = AsyncMock()
|
||||
fetch_chunk = AsyncMock(return_value="Chunk content.")
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
doc_embeddings_client.query.return_value = [
|
||||
ChunkMatch(chunk_id="doc/c1", score=0.9),
|
||||
]
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="concept")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text", text="Answer.",
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self, mock_clients):
|
||||
rag = DocumentRag(*mock_clients)
|
||||
events = []
|
||||
|
||||
async def explain_cb(triples, explain_id):
|
||||
events.append({"explain_id": explain_id, "triples": triples})
|
||||
|
||||
await rag.query(
|
||||
query="test", explain_callback=explain_cb,
|
||||
)
|
||||
|
||||
dag = _collect_events(events)
|
||||
assert len(dag) == 4, f"Expected 4 events, got {len(dag)}"
|
||||
|
||||
_assert_chain(dag, [
|
||||
TG_DOC_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_SYNTHESIS,
|
||||
])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent DAG structure — tested via service.agent_request()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_processor(tools=None):
|
||||
processor = MagicMock()
|
||||
processor.max_iterations = 10
|
||||
processor.save_answer_content = AsyncMock()
|
||||
|
||||
def mock_session_uri(sid):
|
||||
return f"urn:trustgraph:agent:session:{sid}"
|
||||
processor.provenance_session_uri.side_effect = mock_session_uri
|
||||
|
||||
agent = MagicMock()
|
||||
agent.tools = tools or {}
|
||||
agent.additional_context = ""
|
||||
processor.agent = agent
|
||||
processor.aggregator = MagicMock()
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def _make_flow():
|
||||
producers = {}
|
||||
|
||||
def factory(name):
|
||||
if name not in producers:
|
||||
producers[name] = AsyncMock()
|
||||
return producers[name]
|
||||
|
||||
flow = MagicMock(side_effect=factory)
|
||||
return flow
|
||||
|
||||
|
||||
def _collect_agent_events(respond_mock):
|
||||
events = []
|
||||
for call in respond_mock.call_args_list:
|
||||
resp = call[0][0]
|
||||
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
|
||||
events.append({
|
||||
"explain_id": resp.explain_id,
|
||||
"triples": resp.explain_triples,
|
||||
})
|
||||
return events
|
||||
|
||||
|
||||
class TestAgentReactDagStructure:
|
||||
"""
|
||||
Via service.agent_request(), full two-iteration react chain:
|
||||
session → pattern-decision → iteration(1) → observation(1) → final
|
||||
|
||||
Iteration 1: tool call → observation
|
||||
Iteration 2: final answer
|
||||
"""
|
||||
|
||||
def _make_service(self):
|
||||
from trustgraph.agent.orchestrator.service import Processor
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "lookup"
|
||||
mock_tool.description = "Look things up"
|
||||
mock_tool.arguments = []
|
||||
mock_tool.groups = []
|
||||
mock_tool.states = {}
|
||||
mock_tool_impl = AsyncMock(return_value="42")
|
||||
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
|
||||
|
||||
processor = _make_processor(tools={"lookup": mock_tool})
|
||||
|
||||
service = Processor.__new__(Processor)
|
||||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
service.plan_pattern = PlanThenExecutePattern(service)
|
||||
service.supervisor_pattern = SupervisorPattern(service)
|
||||
service.meta_router = None
|
||||
|
||||
return service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self):
|
||||
from trustgraph.agent.react.types import Action, Final
|
||||
|
||||
service = self._make_service()
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = _make_flow()
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Iteration 1: tool call → returns Action, triggers on_action + tool exec
|
||||
action = Action(
|
||||
thought="I need to look this up",
|
||||
name="lookup",
|
||||
arguments={"question": "6x7"},
|
||||
observation="",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
|
||||
) as MockAM:
|
||||
mock_am = AsyncMock()
|
||||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react_iter1(on_action=None, **kwargs):
|
||||
if on_action:
|
||||
await on_action(action)
|
||||
action.observation = "42"
|
||||
return action
|
||||
|
||||
mock_am.react.side_effect = mock_react_iter1
|
||||
|
||||
request1 = AgentRequest(
|
||||
question="What is 6x7?",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=session_id,
|
||||
pattern="react",
|
||||
history=[],
|
||||
)
|
||||
|
||||
await service.agent_request(request1, respond, next_fn, flow)
|
||||
|
||||
# next_fn should have been called with updated history
|
||||
assert next_fn.called
|
||||
|
||||
# Iteration 2: final answer
|
||||
final = Final(thought="The answer is 42", final="42")
|
||||
next_request = next_fn.call_args[0][0]
|
||||
|
||||
with patch(
|
||||
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
|
||||
) as MockAM:
|
||||
mock_am = AsyncMock()
|
||||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react_iter2(**kwargs):
|
||||
return final
|
||||
|
||||
mock_am.react.side_effect = mock_react_iter2
|
||||
|
||||
await service.agent_request(next_request, respond, next_fn, flow)
|
||||
|
||||
# Collect and verify DAG
|
||||
events = _collect_agent_events(respond)
|
||||
dag = _collect_events(events)
|
||||
|
||||
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
|
||||
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
|
||||
analysis_ids = _find_by_type(dag, TG_ANALYSIS)
|
||||
observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE)
|
||||
final_ids = _find_by_type(dag, TG_CONCLUSION)
|
||||
|
||||
assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}"
|
||||
assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}"
|
||||
assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}"
|
||||
assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}"
|
||||
assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}"
|
||||
|
||||
# Full chain:
|
||||
# session → pattern-decision
|
||||
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
|
||||
|
||||
# pattern-decision → iteration(1)
|
||||
assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0]
|
||||
|
||||
# iteration(1) → observation(1)
|
||||
assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0]
|
||||
|
||||
# observation(1) → final
|
||||
assert dag[final_ids[0]]["derived_from"] == observation_ids[0]
|
||||
|
||||
|
||||
class TestAgentPlanDagStructure:
|
||||
"""
|
||||
Via service.agent_request():
|
||||
session → pattern-decision → plan → step-result → synthesis
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self):
|
||||
from trustgraph.agent.orchestrator.service import Processor
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
# Mock tool
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "knowledge-query"
|
||||
mock_tool.description = "Query KB"
|
||||
mock_tool.arguments = []
|
||||
mock_tool.groups = []
|
||||
mock_tool.states = {}
|
||||
mock_tool_impl = AsyncMock(return_value="Found it")
|
||||
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
|
||||
|
||||
processor = _make_processor(tools={"knowledge-query": mock_tool})
|
||||
|
||||
service = Processor.__new__(Processor)
|
||||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
service.plan_pattern = PlanThenExecutePattern(service)
|
||||
service.supervisor_pattern = SupervisorPattern(service)
|
||||
service.meta_router = None
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = _make_flow()
|
||||
|
||||
# Mock prompt client
|
||||
mock_prompt_client = AsyncMock()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_prompt(id, variables=None, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if id == "plan-create":
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}],
|
||||
)
|
||||
elif id == "plan-step-execute":
|
||||
return PromptResult(
|
||||
response_type="json",
|
||||
object={"tool": "knowledge-query", "arguments": {"question": "test"}},
|
||||
)
|
||||
elif id == "plan-synthesise":
|
||||
return PromptResult(response_type="text", text="Final answer.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
mock_prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Iteration 1: planning
|
||||
request1 = AgentRequest(
|
||||
question="Test?",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=session_id,
|
||||
pattern="plan-then-execute",
|
||||
history=[],
|
||||
)
|
||||
await service.agent_request(request1, respond, next_fn, flow)
|
||||
|
||||
# Iteration 2: execute step (next_fn was called with updated request)
|
||||
assert next_fn.called
|
||||
next_request = next_fn.call_args[0][0]
|
||||
|
||||
# Iteration 3: all steps done → synthesis
|
||||
# Simulate completed step in history
|
||||
next_request.history[-1].plan[0].status = "completed"
|
||||
next_request.history[-1].plan[0].result = "Found it"
|
||||
|
||||
await service.agent_request(next_request, respond, next_fn, flow)
|
||||
|
||||
events = _collect_agent_events(respond)
|
||||
dag = _collect_events(events)
|
||||
|
||||
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
|
||||
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
|
||||
plan_ids = _find_by_type(dag, TG_PLAN_TYPE)
|
||||
synthesis_ids = _find_by_type(dag, TG_SYNTHESIS)
|
||||
|
||||
assert len(session_ids) == 1
|
||||
assert len(pd_ids) == 1
|
||||
assert len(plan_ids) == 1
|
||||
assert len(synthesis_ids) == 1
|
||||
|
||||
# Chain: session → pattern-decision → plan → ... → synthesis
|
||||
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
|
||||
assert dag[plan_ids[0]]["derived_from"] == pd_ids[0]
|
||||
|
||||
|
||||
class TestAgentSupervisorDagStructure:
|
||||
"""
|
||||
Via service.agent_request():
|
||||
session → pattern-decision → decomposition → (fan-out)
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self):
|
||||
from trustgraph.agent.orchestrator.service import Processor
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
processor = _make_processor()
|
||||
|
||||
service = Processor.__new__(Processor)
|
||||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
service.plan_pattern = PlanThenExecutePattern(service)
|
||||
service.supervisor_pattern = SupervisorPattern(service)
|
||||
service.meta_router = None
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = _make_flow()
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=["Goal A", "Goal B"],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
request = AgentRequest(
|
||||
question="Research quantum computing",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=str(uuid.uuid4()),
|
||||
pattern="supervisor",
|
||||
history=[],
|
||||
)
|
||||
|
||||
await service.agent_request(request, respond, next_fn, flow)
|
||||
|
||||
events = _collect_agent_events(respond)
|
||||
dag = _collect_events(events)
|
||||
|
||||
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
|
||||
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
|
||||
decomp_ids = _find_by_type(dag, TG_DECOMPOSITION)
|
||||
|
||||
assert len(session_ids) == 1
|
||||
assert len(pd_ids) == 1
|
||||
assert len(decomp_ids) == 1
|
||||
|
||||
# Chain: session → pattern-decision → decomposition
|
||||
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
|
||||
assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0]
|
||||
|
||||
# Fan-out should have been called
|
||||
assert next_fn.call_count == 2 # One per goal
|
||||
|
|
@ -223,7 +223,7 @@ class TestDerivedEntityTriples:
|
|||
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
|
||||
|
||||
def test_chunk_entity_has_chunk_type(self):
|
||||
def test_chunk_entity_has_message_type(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"chunker", "1.0",
|
||||
|
|
|
|||
|
|
@ -304,14 +304,14 @@ class TestStreamingTypes:
|
|||
|
||||
assert chunk.content == "thinking..."
|
||||
assert chunk.end_of_message is False
|
||||
assert chunk.chunk_type == "thought"
|
||||
assert chunk.message_type == "thought"
|
||||
|
||||
def test_agent_observation_creation(self):
|
||||
"""Test creating AgentObservation chunk"""
|
||||
chunk = AgentObservation(content="observing...", end_of_message=False)
|
||||
|
||||
assert chunk.content == "observing..."
|
||||
assert chunk.chunk_type == "observation"
|
||||
assert chunk.message_type == "observation"
|
||||
|
||||
def test_agent_answer_creation(self):
|
||||
"""Test creating AgentAnswer chunk"""
|
||||
|
|
@ -324,7 +324,7 @@ class TestStreamingTypes:
|
|||
assert chunk.content == "answer"
|
||||
assert chunk.end_of_message is True
|
||||
assert chunk.end_of_dialog is True
|
||||
assert chunk.chunk_type == "final-answer"
|
||||
assert chunk.message_type == "final-answer"
|
||||
|
||||
def test_rag_chunk_creation(self):
|
||||
"""Test creating RAGChunk"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue