mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
593 lines
21 KiB
Python
593 lines
21 KiB
Python
|
|
"""
|
||
|
|
DAG structure tests for provenance chains.
|
||
|
|
|
||
|
|
Verifies that the wasDerivedFrom chain has the expected shape for each
|
||
|
|
service. These tests catch structural regressions when new entities are
|
||
|
|
inserted into the chain (e.g. PatternDecision between session and first
|
||
|
|
iteration).
|
||
|
|
|
||
|
|
Expected chains:
|
||
|
|
|
||
|
|
GraphRAG: question → grounding → exploration → focus → synthesis
|
||
|
|
DocumentRAG: question → grounding → exploration → synthesis
|
||
|
|
Agent React: session → pattern-decision → iteration → (observation → iteration)* → final
|
||
|
|
Agent Plan: session → pattern-decision → plan → step-result(s) → synthesis
|
||
|
|
Agent Super: session → pattern-decision → decomposition → (fan-out) → finding(s) → synthesis
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import uuid
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
|
|
||
|
|
from trustgraph.schema import (
|
||
|
|
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||
|
|
Triple, Term, IRI, LITERAL,
|
||
|
|
)
|
||
|
|
from trustgraph.base import PromptResult
|
||
|
|
|
||
|
|
from trustgraph.provenance.namespaces import (
|
||
|
|
RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL,
|
||
|
|
TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
|
||
|
|
TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||
|
|
TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION,
|
||
|
|
TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION,
|
||
|
|
TG_OBSERVATION_TYPE,
|
||
|
|
TG_PATTERN, TG_TASK_TYPE,
|
||
|
|
)
|
||
|
|
|
||
|
|
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||
|
|
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Helpers
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
def _collect_events(events):
    """Build a dict of explain_id → {types, derived_from, triples}.

    Each event is a dict with an ``explain_id`` and its emitted ``triples``.
    Only triples whose subject is the event's own ID contribute to the
    ``types`` set and the ``derived_from`` parent link.
    """
    dag = {}
    for event in events:
        subject = event["explain_id"]
        triples = event["triples"]
        # Restrict to triples asserted about this event itself.
        own = [t for t in triples if t.s.iri == subject]
        types = {t.o.iri for t in own if t.p.iri == RDF_TYPE}
        parents = [t.o.iri for t in own if t.p.iri == PROV_WAS_DERIVED_FROM]
        dag[subject] = {
            "types": types,
            # Only the first parent is tracked — chains here are linear.
            "derived_from": parents[0] if parents else None,
            "triples": triples,
        }
    return dag
|
||
|
|
|
||
|
|
|
||
|
|
def _find_by_type(dag, rdf_type):
|
||
|
|
"""Find all event IDs that have the given rdf:type."""
|
||
|
|
return [eid for eid, info in dag.items() if rdf_type in info["types"]]
|
||
|
|
|
||
|
|
|
||
|
|
def _assert_chain(dag, chain_types):
    """Assert that a linear wasDerivedFrom chain exists through the given types.

    For each consecutive (parent, child) pair in *chain_types*, at least one
    entity of the child type must derive from some entity of the parent type.
    """
    for parent_type, child_type in zip(chain_types, chain_types[1:]):
        parents = _find_by_type(dag, parent_type)
        children = _find_by_type(dag, child_type)
        assert parents, f"No entity with type {parent_type}"
        assert children, f"No entity with type {child_type}"
        # A single parent→child derivation link is sufficient.
        linked = any(
            dag[child_id]["derived_from"] in parents
            for child_id in children
        )
        assert linked, (
            f"No {child_type} derives from {parent_type}. "
            f"Children derive from: "
            f"{[dag[c]['derived_from'] for c in children]}"
        )
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# GraphRAG DAG structure
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestGraphRagDagStructure:
    """Verify: question → grounding → exploration → focus → synthesis"""

    @pytest.fixture
    def mock_clients(self):
        """Stub the four service clients GraphRag is constructed from."""
        prompt_client = AsyncMock()
        embeddings_client = AsyncMock()
        graph_embeddings_client = AsyncMock()
        triples_client = AsyncMock()

        embeddings_client.embed.return_value = [[0.1, 0.2]]
        graph_embeddings_client.query.return_value = [
            MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")),
        ]
        triples_client.query_stream.return_value = [
            Triple(
                s=Term(type=IRI, iri="http://example.com/e1"),
                p=Term(type=IRI, iri="http://example.com/p"),
                o=Term(type=LITERAL, value="value"),
            )
        ]
        triples_client.query.return_value = []

        async def mock_prompt(template_id, variables=None, **kwargs):
            # Route each known prompt template to a canned response.
            if template_id == "extract-concepts":
                return PromptResult(response_type="text", text="concept")
            if template_id == "kg-edge-scoring":
                knowledge = variables.get("knowledge", [])
                return PromptResult(
                    response_type="jsonl",
                    objects=[{"id": e["id"], "score": 10} for e in knowledge],
                )
            if template_id == "kg-edge-reasoning":
                knowledge = variables.get("knowledge", [])
                return PromptResult(
                    response_type="jsonl",
                    objects=[
                        {"id": e["id"], "reasoning": "relevant"}
                        for e in knowledge
                    ],
                )
            if template_id == "kg-synthesis":
                return PromptResult(response_type="text", text="Answer.")
            return PromptResult(response_type="text", text="")

        prompt_client.prompt.side_effect = mock_prompt
        return prompt_client, embeddings_client, graph_embeddings_client, triples_client

    @pytest.mark.asyncio
    async def test_dag_chain(self, mock_clients):
        """Run a full GraphRag query and verify the 5-node provenance chain."""
        rag = GraphRag(*mock_clients)
        captured = []

        async def explain_cb(triples, explain_id):
            captured.append({"explain_id": explain_id, "triples": triples})

        await rag.query(
            query="test", explain_callback=explain_cb, edge_score_limit=0,
        )

        dag = _collect_events(captured)
        assert len(dag) == 5, f"Expected 5 events, got {len(dag)}"

        _assert_chain(dag, [
            TG_GRAPH_RAG_QUESTION,
            TG_GROUNDING,
            TG_EXPLORATION,
            TG_FOCUS,
            TG_SYNTHESIS,
        ])
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# DocumentRAG DAG structure
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestDocumentRagDagStructure:
    """Verify: question → grounding → exploration → synthesis"""

    @pytest.fixture
    def mock_clients(self):
        """Stub the clients and chunk fetcher DocumentRag is built from."""
        from trustgraph.schema import ChunkMatch

        prompt_client = AsyncMock()
        embeddings_client = AsyncMock()
        doc_embeddings_client = AsyncMock()
        fetch_chunk = AsyncMock(return_value="Chunk content.")

        embeddings_client.embed.return_value = [[0.1, 0.2]]
        doc_embeddings_client.query.return_value = [
            ChunkMatch(chunk_id="doc/c1", score=0.9),
        ]

        async def mock_prompt(template_id, variables=None, **kwargs):
            if template_id == "extract-concepts":
                return PromptResult(response_type="text", text="concept")
            return PromptResult(response_type="text", text="")

        prompt_client.prompt.side_effect = mock_prompt
        # Synthesis goes through document_prompt rather than prompt.
        prompt_client.document_prompt.return_value = PromptResult(
            response_type="text", text="Answer.",
        )

        return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk

    @pytest.mark.asyncio
    async def test_dag_chain(self, mock_clients):
        """Run a full DocumentRag query and verify the 4-node provenance chain."""
        rag = DocumentRag(*mock_clients)
        captured = []

        async def explain_cb(triples, explain_id):
            captured.append({"explain_id": explain_id, "triples": triples})

        await rag.query(
            query="test", explain_callback=explain_cb,
        )

        dag = _collect_events(captured)
        assert len(dag) == 4, f"Expected 4 events, got {len(dag)}"

        _assert_chain(dag, [
            TG_DOC_RAG_QUESTION,
            TG_GROUNDING,
            TG_EXPLORATION,
            TG_SYNTHESIS,
        ])
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Agent DAG structure — tested via service.agent_request()
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
def _make_processor(tools=None):
|
||
|
|
processor = MagicMock()
|
||
|
|
processor.max_iterations = 10
|
||
|
|
processor.save_answer_content = AsyncMock()
|
||
|
|
|
||
|
|
def mock_session_uri(sid):
|
||
|
|
return f"urn:trustgraph:agent:session:{sid}"
|
||
|
|
processor.provenance_session_uri.side_effect = mock_session_uri
|
||
|
|
|
||
|
|
agent = MagicMock()
|
||
|
|
agent.tools = tools or {}
|
||
|
|
agent.additional_context = ""
|
||
|
|
processor.agent = agent
|
||
|
|
processor.aggregator = MagicMock()
|
||
|
|
|
||
|
|
return processor
|
||
|
|
|
||
|
|
|
||
|
|
def _make_flow():
|
||
|
|
producers = {}
|
||
|
|
|
||
|
|
def factory(name):
|
||
|
|
if name not in producers:
|
||
|
|
producers[name] = AsyncMock()
|
||
|
|
return producers[name]
|
||
|
|
|
||
|
|
flow = MagicMock(side_effect=factory)
|
||
|
|
return flow
|
||
|
|
|
||
|
|
|
||
|
|
def _collect_agent_events(respond_mock):
    """Extract explain events from every response sent through *respond_mock*.

    Scans the mock's recorded calls for AgentResponse objects with
    message_type "explain" and returns them in call order as dicts with
    ``explain_id`` and ``triples`` keys (the shape _collect_events expects).
    """
    responses = (call[0][0] for call in respond_mock.call_args_list)
    return [
        {"explain_id": r.explain_id, "triples": r.explain_triples}
        for r in responses
        if isinstance(r, AgentResponse) and r.message_type == "explain"
    ]
|
||
|
|
|
||
|
|
|
||
|
|
class TestAgentReactDagStructure:
    """
    Via service.agent_request(), full two-iteration react chain:
    session → pattern-decision → iteration(1) → observation(1) → final

    Iteration 1: tool call → observation
    Iteration 2: final answer
    """

    def _make_service(self):
        """Construct an orchestrator Processor without running __init__.

        Attributes the patterns rely on are assigned by hand; the real
        __init__ is bypassed via __new__ so no messaging/IO is set up.
        """
        from trustgraph.agent.orchestrator.service import Processor
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        # A single "lookup" tool whose implementation always returns "42".
        mock_tool = MagicMock()
        mock_tool.name = "lookup"
        mock_tool.description = "Look things up"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="42")
        # implementation is a factory: called to produce the async callable.
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = _make_processor(tools={"lookup": mock_tool})

        # Bypass Processor.__init__ (which presumably wires messaging —
        # TODO confirm) and copy over only the attributes patterns use.
        service = Processor.__new__(Processor)
        service.max_iterations = 10
        service.save_answer_content = AsyncMock()
        service.provenance_session_uri = processor.provenance_session_uri
        service.agent = processor.agent
        service.aggregator = processor.aggregator

        service.react_pattern = ReactPattern(service)
        service.plan_pattern = PlanThenExecutePattern(service)
        service.supervisor_pattern = SupervisorPattern(service)
        service.meta_router = None

        return service

    @pytest.mark.asyncio
    async def test_dag_chain(self):
        """Drive two agent_request() calls and verify the react provenance DAG."""
        from trustgraph.agent.react.types import Action, Final

        service = self._make_service()

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = _make_flow()
        session_id = str(uuid.uuid4())

        # Iteration 1: tool call → returns Action, triggers on_action + tool exec
        action = Action(
            thought="I need to look this up",
            name="lookup",
            arguments={"question": "6x7"},
            observation="",
        )

        with patch(
            "trustgraph.agent.orchestrator.react_pattern.AgentManager"
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react_iter1(on_action=None, **kwargs):
                # Mimic the real manager: fire the on_action hook, then
                # return the Action with its observation filled in.
                if on_action:
                    await on_action(action)
                action.observation = "42"
                return action

            mock_am.react.side_effect = mock_react_iter1

            request1 = AgentRequest(
                question="What is 6x7?",
                user="testuser",
                collection="default",
                streaming=False,
                session_id=session_id,
                pattern="react",
                history=[],
            )

            await service.agent_request(request1, respond, next_fn, flow)

        # next_fn should have been called with updated history
        assert next_fn.called

        # Iteration 2: final answer
        final = Final(thought="The answer is 42", final="42")
        # Re-submit the follow-up request the service queued via next_fn.
        next_request = next_fn.call_args[0][0]

        with patch(
            "trustgraph.agent.orchestrator.react_pattern.AgentManager"
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react_iter2(**kwargs):
                return final

            mock_am.react.side_effect = mock_react_iter2

            await service.agent_request(next_request, respond, next_fn, flow)

        # Collect and verify DAG
        events = _collect_agent_events(respond)
        dag = _collect_events(events)

        session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
        pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
        analysis_ids = _find_by_type(dag, TG_ANALYSIS)
        observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE)
        final_ids = _find_by_type(dag, TG_CONCLUSION)

        assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}"
        assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}"
        assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}"
        assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}"
        assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}"

        # Full chain:
        # session → pattern-decision
        assert dag[pd_ids[0]]["derived_from"] == session_ids[0]

        # pattern-decision → iteration(1)
        assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0]

        # iteration(1) → observation(1)
        assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0]

        # observation(1) → final
        assert dag[final_ids[0]]["derived_from"] == observation_ids[0]
|
||
|
|
|
||
|
|
|
||
|
|
class TestAgentPlanDagStructure:
    """
    Via service.agent_request():
    session → pattern-decision → plan → step-result → synthesis
    """

    @pytest.mark.asyncio
    async def test_dag_chain(self):
        """Drive the plan-then-execute pattern end-to-end and verify its DAG.

        Fix: removed the dead ``call_count`` counter — it was incremented via
        ``nonlocal`` inside the prompt mock but never read anywhere.
        """
        from trustgraph.agent.orchestrator.service import Processor
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        # Mock tool
        mock_tool = MagicMock()
        mock_tool.name = "knowledge-query"
        mock_tool.description = "Query KB"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="Found it")
        # implementation is a factory: called to produce the async callable.
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = _make_processor(tools={"knowledge-query": mock_tool})

        # Bypass Processor.__init__ and hand-assign the attributes the
        # patterns use (mirrors TestAgentReactDagStructure's setup).
        service = Processor.__new__(Processor)
        service.max_iterations = 10
        service.save_answer_content = AsyncMock()
        service.provenance_session_uri = processor.provenance_session_uri
        service.agent = processor.agent
        service.aggregator = processor.aggregator

        service.react_pattern = ReactPattern(service)
        service.plan_pattern = PlanThenExecutePattern(service)
        service.supervisor_pattern = SupervisorPattern(service)
        service.meta_router = None

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = _make_flow()

        # Mock prompt client
        mock_prompt_client = AsyncMock()

        # NOTE: the parameter is named `id` (shadowing the builtin) because
        # the plan pattern may pass the template id by keyword — verify
        # against the caller before renaming.
        async def mock_prompt(id, variables=None, **kwargs):
            if id == "plan-create":
                return PromptResult(
                    response_type="jsonl",
                    objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}],
                )
            elif id == "plan-step-execute":
                return PromptResult(
                    response_type="json",
                    object={"tool": "knowledge-query", "arguments": {"question": "test"}},
                )
            elif id == "plan-synthesise":
                return PromptResult(response_type="text", text="Final answer.")
            return PromptResult(response_type="text", text="")

        mock_prompt_client.prompt.side_effect = mock_prompt

        def flow_factory(name):
            # Only the prompt queue gets the scripted client.
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()
        flow.side_effect = flow_factory

        session_id = str(uuid.uuid4())

        # Iteration 1: planning
        request1 = AgentRequest(
            question="Test?",
            user="testuser",
            collection="default",
            streaming=False,
            session_id=session_id,
            pattern="plan-then-execute",
            history=[],
        )
        await service.agent_request(request1, respond, next_fn, flow)

        # Iteration 2: execute step (next_fn was called with updated request)
        assert next_fn.called
        next_request = next_fn.call_args[0][0]

        # Iteration 3: all steps done → synthesis
        # Simulate completed step in history
        next_request.history[-1].plan[0].status = "completed"
        next_request.history[-1].plan[0].result = "Found it"

        await service.agent_request(next_request, respond, next_fn, flow)

        events = _collect_agent_events(respond)
        dag = _collect_events(events)

        session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
        pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
        plan_ids = _find_by_type(dag, TG_PLAN_TYPE)
        synthesis_ids = _find_by_type(dag, TG_SYNTHESIS)

        assert len(session_ids) == 1
        assert len(pd_ids) == 1
        assert len(plan_ids) == 1
        assert len(synthesis_ids) == 1

        # Chain: session → pattern-decision → plan → ... → synthesis
        assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
        assert dag[plan_ids[0]]["derived_from"] == pd_ids[0]
|
||
|
|
|
||
|
|
|
||
|
|
class TestAgentSupervisorDagStructure:
    """
    Via service.agent_request():
    session → pattern-decision → decomposition → (fan-out)
    """

    @pytest.mark.asyncio
    async def test_dag_chain(self):
        """Drive the supervisor pattern once and verify decomposition + fan-out."""
        from trustgraph.agent.orchestrator.service import Processor
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        processor = _make_processor()

        # Bypass Processor.__init__ and hand-assign the attributes the
        # patterns use (mirrors the react/plan test setup).
        service = Processor.__new__(Processor)
        service.max_iterations = 10
        service.save_answer_content = AsyncMock()
        service.provenance_session_uri = processor.provenance_session_uri
        service.agent = processor.agent
        service.aggregator = processor.aggregator

        service.react_pattern = ReactPattern(service)
        service.plan_pattern = PlanThenExecutePattern(service)
        service.supervisor_pattern = SupervisorPattern(service)
        service.meta_router = None

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = _make_flow()

        # Decomposition prompt returns two sub-goals → expect fan-out of 2.
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = PromptResult(
            response_type="jsonl",
            objects=["Goal A", "Goal B"],
        )

        def flow_factory(name):
            # Only the prompt queue gets the scripted client.
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()
        flow.side_effect = flow_factory

        request = AgentRequest(
            question="Research quantum computing",
            user="testuser",
            collection="default",
            streaming=False,
            session_id=str(uuid.uuid4()),
            pattern="supervisor",
            history=[],
        )

        await service.agent_request(request, respond, next_fn, flow)

        events = _collect_agent_events(respond)
        dag = _collect_events(events)

        session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
        pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
        decomp_ids = _find_by_type(dag, TG_DECOMPOSITION)

        assert len(session_ids) == 1
        assert len(pd_ids) == 1
        assert len(decomp_ids) == 1

        # Chain: session → pattern-decision → decomposition
        assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
        assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0]

        # Fan-out should have been called
        assert next_fn.call_count == 2  # One per goal
|