Update tests for agent-orchestrator (#745)

Add 96 tests covering the orchestrator's aggregation, provenance,
routing, and explainability parsing. These verify the supervisor
fan-out/fan-in lifecycle, the new RDF provenance types
(Decomposition, Finding, Plan, StepResult, Synthesis), and their
round-trip through the wire format.

Unit tests (84):
- Aggregator: register, record completion, peek, build synthesis,
  cleanup
- Provenance triple builders: types, provenance links,
  goals/steps, labels
- Explainability parsing: from_triples dispatch, field extraction
  for all new entity types, precedence over existing types
- PatternBase: is_subagent detection, emit_subagent_completion
  message shape
- Completion dispatch: detection logic, full aggregator
  integration flow, synthesis request not re-intercepted as
  completion
- MetaRouter: task type identification, pattern selection,
  valid_patterns constraints, fallback on LLM error or unknown
  response

Contract tests (12):
- Orchestration fields on AgentRequest round-trip correctly
- subagent-completion and synthesise step types in request
  history
- Plan steps with status and dependencies
- Provenance triple builder → wire format → from_triples
  round-trip for all five new entity types
This commit is contained in:
cybermaggedon 2026-03-31 13:12:26 +01:00 committed by GitHub
parent 7b734148b3
commit 816a8cfcf6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 1517 additions and 0 deletions

View file

@ -0,0 +1,177 @@
"""
Contract tests for orchestrator message schemas.
Verifies that AgentRequest/AgentStep with orchestration fields
serialise and deserialise correctly through the Pulsar schema layer.
"""
import pytest
import json
from trustgraph.schema import AgentRequest, AgentStep, PlanStep
@pytest.mark.contract
class TestOrchestrationFieldContracts:
    """Contract tests for orchestration fields on AgentRequest."""

    def test_agent_request_orchestration_fields_roundtrip(self):
        # Every orchestration field should survive construction unchanged.
        orchestration = {
            "correlation_id": "corr-123",
            "parent_session_id": "parent-sess",
            "subagent_goal": "What is X?",
            "expected_siblings": 4,
            "pattern": "react",
            "task_type": "research",
            "framing": "Focus on accuracy",
            "conversation_id": "conv-456",
        }
        req = AgentRequest(
            question="Test question",
            user="testuser",
            collection="default",
            **orchestration,
        )
        for field, expected in orchestration.items():
            assert getattr(req, field) == expected

    def test_agent_request_orchestration_fields_default_empty(self):
        req = AgentRequest(
            question="Test question",
            user="testuser",
        )
        # String-valued orchestration fields default to "", the sibling
        # counter to 0.
        for field in ("correlation_id", "parent_session_id",
                      "subagent_goal", "pattern", "task_type", "framing"):
            assert getattr(req, field) == ""
        assert req.expected_siblings == 0
@pytest.mark.contract
class TestSubagentCompletionStepContract:
    """Contract tests for subagent-completion step type."""

    def _completion_step(self, observation):
        # Build a completion step shaped like the one a subagent emits.
        return AgentStep(
            thought="Subagent completed",
            action="complete",
            arguments={},
            observation=observation,
            step_type="subagent-completion",
        )

    def test_subagent_completion_step_fields(self):
        step = self._completion_step("The answer text")
        assert step.step_type == "subagent-completion"
        assert step.observation == "The answer text"
        assert step.thought == "Subagent completed"
        assert step.action == "complete"

    def test_subagent_completion_in_request_history(self):
        req = AgentRequest(
            question="goal",
            user="testuser",
            correlation_id="corr-123",
            history=[self._completion_step("answer")],
        )
        assert len(req.history) == 1
        assert req.history[0].step_type == "subagent-completion"
        assert req.history[0].observation == "answer"
@pytest.mark.contract
class TestSynthesisStepContract:
    """Contract tests for synthesis step type with subagent_results."""

    def _synthesis_step(self, results):
        # Build the synthesise step exactly as the aggregator would.
        return AgentStep(
            thought="All subagents completed",
            action="aggregate",
            arguments={},
            observation=json.dumps(results),
            step_type="synthesise",
            subagent_results=results,
        )

    def test_synthesis_step_with_results(self):
        results = {"goal-a": "answer-a", "goal-b": "answer-b"}
        step = self._synthesis_step(results)
        assert step.step_type == "synthesise"
        assert step.subagent_results == results
        assert json.loads(step.observation) == results

    def test_synthesis_request_matches_supervisor_expectations(self):
        """The synthesis request built by the aggregator must be
        recognisable by SupervisorPattern._synthesise()."""
        results = {"goal-a": "answer-a", "goal-b": "answer-b"}
        req = AgentRequest(
            question="Original question",
            user="testuser",
            pattern="supervisor",
            correlation_id="",
            session_id="parent-sess",
            history=[self._synthesis_step(results)],
        )
        # SupervisorPattern checks for step_type='synthesise' with
        # subagent_results
        assert req.history and any(
            getattr(h, 'step_type', '') == 'synthesise'
            and getattr(h, 'subagent_results', None)
            for h in req.history
        )
        # Pattern must be supervisor
        assert req.pattern == "supervisor"
        # Correlation ID must be empty (not re-intercepted)
        assert req.correlation_id == ""
@pytest.mark.contract
class TestPlanStepContract:
    """Contract tests for plan steps in history."""

    def test_plan_step_in_history(self):
        # Two steps: one finished, one pending that depends on the first.
        completed = PlanStep(goal="Step 1", tool_hint="knowledge-query",
                             depends_on=[], status="completed", result="done")
        pending = PlanStep(goal="Step 2", tool_hint="",
                           depends_on=[0], status="pending", result="")
        step = AgentStep(
            thought="Created plan",
            action="plan",
            step_type="plan",
            plan=[completed, pending],
        )
        assert step.step_type == "plan"
        assert len(step.plan) == 2
        assert step.plan[0].goal == "Step 1"
        assert step.plan[0].status == "completed"
        assert step.plan[1].depends_on == [0]

View file

@ -0,0 +1,129 @@
"""
Contract tests for the provenance triple wire format. These verify that
triples built by the provenance library can be parsed by the
explainability API after conversion through the wire format.
"""
import pytest
from trustgraph.schema import IRI, LITERAL
from trustgraph.provenance import (
agent_decomposition_triples,
agent_finding_triples,
agent_plan_triples,
agent_step_result_triples,
agent_synthesis_triples,
)
from trustgraph.api.explainability import (
ExplainEntity,
Decomposition,
Finding,
Plan,
StepResult,
Synthesis,
wire_triples_to_tuples,
)
def _triples_to_wire(triples):
    """Convert provenance Triple objects to the wire format dicts
    that the gateway/socket client would produce."""
    return [
        {
            "s": _term_to_wire(triple.s),
            "p": _term_to_wire(triple.p),
            "o": _term_to_wire(triple.o),
        }
        for triple in triples
    ]
def _term_to_wire(term):
    """Convert a Term to wire format dict."""
    kind = term.type
    if kind == IRI:
        return {"t": "i", "i": term.iri}
    if kind == LITERAL:
        return {"t": "l", "v": term.value}
    # Unknown term kinds degrade to a string literal.
    return {"t": "l", "v": str(term)}
def _roundtrip(triples, uri):
    """Convert triples through wire format and parse via from_triples."""
    as_tuples = wire_triples_to_tuples(_triples_to_wire(triples))
    return ExplainEntity.from_triples(uri, as_tuples)
@pytest.mark.contract
class TestDecompositionWireFormat:
    def test_roundtrip(self):
        goals = ["What is X?", "What is Y?"]
        parsed = _roundtrip(
            agent_decomposition_triples("urn:decompose", "urn:session", goals),
            "urn:decompose",
        )
        assert isinstance(parsed, Decomposition)
        assert set(parsed.goals) == set(goals)
@pytest.mark.contract
class TestFindingWireFormat:
    def test_roundtrip(self):
        built = agent_finding_triples(
            "urn:finding", "urn:decompose", "What is X?",
            document_id="urn:doc/finding",
        )
        parsed = _roundtrip(built, "urn:finding")
        assert isinstance(parsed, Finding)
        assert parsed.goal == "What is X?"
        assert parsed.document == "urn:doc/finding"
@pytest.mark.contract
class TestPlanWireFormat:
    def test_roundtrip(self):
        steps = ["Step 1", "Step 2", "Step 3"]
        parsed = _roundtrip(
            agent_plan_triples("urn:plan", "urn:session", steps),
            "urn:plan",
        )
        assert isinstance(parsed, Plan)
        assert set(parsed.steps) == set(steps)
@pytest.mark.contract
class TestStepResultWireFormat:
    def test_roundtrip(self):
        built = agent_step_result_triples(
            "urn:step", "urn:plan", "Define X",
            document_id="urn:doc/step",
        )
        parsed = _roundtrip(built, "urn:step")
        assert isinstance(parsed, StepResult)
        assert parsed.step == "Define X"
        assert parsed.document == "urn:doc/step"
@pytest.mark.contract
class TestSynthesisWireFormat:
    def test_roundtrip(self):
        built = agent_synthesis_triples(
            "urn:synthesis", "urn:previous",
            document_id="urn:doc/synthesis",
        )
        parsed = _roundtrip(built, "urn:synthesis")
        assert isinstance(parsed, Synthesis)
        assert parsed.document == "urn:doc/synthesis"

View file

@ -0,0 +1,216 @@
"""
Unit tests for the Aggregator, which tracks fan-out correlations and
triggers synthesis when all subagents complete.
"""
import time
import pytest
from trustgraph.schema import AgentRequest, AgentStep
from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(question="Test question", user="testuser",
                  collection="default", streaming=False,
                  session_id="parent-session", task_type="research",
                  framing="test framing", conversation_id="conv-1"):
    """Build an AgentRequest with sensible test defaults; callers override
    only the fields they care about."""
    return AgentRequest(
        question=question, user=user, collection=collection,
        streaming=streaming, session_id=session_id,
        task_type=task_type, framing=framing,
        conversation_id=conversation_id,
    )
class TestRegisterFanout:
    def test_stores_correlation_entry(self):
        aggregator = Aggregator()
        aggregator.register_fanout("corr-1", "parent-1", 3)
        assert "corr-1" in aggregator.correlations
        entry = aggregator.correlations["corr-1"]
        assert entry["parent_session_id"] == "parent-1"
        assert entry["expected"] == 3
        assert entry["results"] == {}

    def test_stores_request_template(self):
        aggregator = Aggregator()
        template = _make_request()
        aggregator.register_fanout("corr-1", "parent-1", 2,
                                   request_template=template)
        # The very same object, not a copy.
        assert aggregator.correlations["corr-1"]["request_template"] \
            is template

    def test_records_creation_time(self):
        aggregator = Aggregator()
        before = time.time()
        aggregator.register_fanout("corr-1", "parent-1", 2)
        after = time.time()
        created = aggregator.correlations["corr-1"]["created_at"]
        assert before <= created <= after
class TestRecordCompletion:
    def test_returns_false_until_all_done(self):
        aggregator = Aggregator()
        aggregator.register_fanout("corr-1", "parent-1", 3)
        # Only the last of the three completions reports "all done".
        assert aggregator.record_completion(
            "corr-1", "goal-a", "answer-a") is False
        assert aggregator.record_completion(
            "corr-1", "goal-b", "answer-b") is False
        assert aggregator.record_completion(
            "corr-1", "goal-c", "answer-c") is True

    def test_returns_none_for_unknown_correlation(self):
        aggregator = Aggregator()
        assert aggregator.record_completion(
            "unknown", "goal", "answer") is None

    def test_stores_results_by_goal(self):
        aggregator = Aggregator()
        aggregator.register_fanout("corr-1", "parent-1", 2)
        for goal, answer in (("goal-a", "answer-a"), ("goal-b", "answer-b")):
            aggregator.record_completion("corr-1", goal, answer)
        assert aggregator.correlations["corr-1"]["results"] == {
            "goal-a": "answer-a",
            "goal-b": "answer-b",
        }

    def test_single_subagent(self):
        aggregator = Aggregator()
        aggregator.register_fanout("corr-1", "parent-1", 1)
        assert aggregator.record_completion(
            "corr-1", "goal-a", "answer") is True
class TestGetOriginalRequest:
    def test_peeks_without_consuming(self):
        aggregator = Aggregator()
        template = _make_request()
        aggregator.register_fanout("corr-1", "parent-1", 2,
                                   request_template=template)
        assert aggregator.get_original_request("corr-1") is template
        # Entry still exists
        assert "corr-1" in aggregator.correlations

    def test_returns_none_for_unknown(self):
        assert Aggregator().get_original_request("unknown") is None
class TestBuildSynthesisRequest:
    """Tests for build_synthesis_request(): the synthesis request carries
    the template's fields, embeds the collected results, and consumes the
    correlation entry."""

    def test_builds_correct_request(self):
        agg = Aggregator()
        template = _make_request(
            question="Original question",
            streaming=True,
            task_type="risk-assessment",
            framing="Assess risks",
        )
        agg.register_fanout("corr-1", "parent-1", 2,
                            request_template=template)
        agg.record_completion("corr-1", "goal-a", "answer-a")
        agg.record_completion("corr-1", "goal-b", "answer-b")
        req = agg.build_synthesis_request(
            "corr-1",
            original_question="Original question",
            user="testuser",
            collection="default",
        )
        assert req.question == "Original question"
        assert req.pattern == "supervisor"
        assert req.session_id == "parent-1"
        assert req.correlation_id == ""  # Must be empty
        # Identity check (`is True`) rather than `== True` (E712): the
        # schema field must be a genuine bool — equality would also pass
        # for the integer 1.
        assert req.streaming is True
        assert req.task_type == "risk-assessment"
        assert req.framing == "Assess risks"

    def test_synthesis_step_in_history(self):
        agg = Aggregator()
        template = _make_request()
        agg.register_fanout("corr-1", "parent-1", 2,
                            request_template=template)
        agg.record_completion("corr-1", "goal-a", "answer-a")
        agg.record_completion("corr-1", "goal-b", "answer-b")
        req = agg.build_synthesis_request(
            "corr-1", "question", "user", "default",
        )
        # Last history step should be the synthesis step
        assert len(req.history) >= 1
        synth_step = req.history[-1]
        assert synth_step.step_type == "synthesise"
        assert synth_step.subagent_results == {
            "goal-a": "answer-a",
            "goal-b": "answer-b",
        }

    def test_consumes_correlation_entry(self):
        agg = Aggregator()
        template = _make_request()
        agg.register_fanout("corr-1", "parent-1", 1,
                            request_template=template)
        agg.record_completion("corr-1", "goal-a", "answer-a")
        agg.build_synthesis_request(
            "corr-1", "question", "user", "default",
        )
        # Entry should be removed
        assert "corr-1" not in agg.correlations

    def test_raises_for_unknown_correlation(self):
        agg = Aggregator()
        with pytest.raises(RuntimeError, match="No results"):
            agg.build_synthesis_request(
                "unknown", "question", "user", "default",
            )
class TestCleanupStale:
    def test_removes_entries_older_than_timeout(self):
        aggregator = Aggregator(timeout=1)
        aggregator.register_fanout("corr-1", "parent-1", 2)
        # Backdate the creation time
        aggregator.correlations["corr-1"]["created_at"] = time.time() - 2
        expired = aggregator.cleanup_stale()
        assert "corr-1" in expired
        assert "corr-1" not in aggregator.correlations

    def test_keeps_recent_entries(self):
        aggregator = Aggregator(timeout=300)
        aggregator.register_fanout("corr-1", "parent-1", 2)
        assert aggregator.cleanup_stale() == []
        assert "corr-1" in aggregator.correlations

    def test_mixed_stale_and_fresh(self):
        aggregator = Aggregator(timeout=1)
        aggregator.register_fanout("stale", "parent-1", 2)
        aggregator.register_fanout("fresh", "parent-2", 2)
        aggregator.correlations["stale"]["created_at"] = time.time() - 2
        expired = aggregator.cleanup_stale()
        assert "stale" in expired
        assert "stale" not in aggregator.correlations
        assert "fresh" in aggregator.correlations

View file

@ -0,0 +1,174 @@
"""
Unit tests for completion dispatch. These verify that agent_request() in the
orchestrator service correctly intercepts subagent completion messages and
routes them to _handle_subagent_completion.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.schema import AgentRequest, AgentStep
from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(**kwargs):
    """Build an AgentRequest, overriding the test defaults with kwargs."""
    return AgentRequest(**{
        "question": "Test question",
        "user": "testuser",
        "collection": "default",
        **kwargs,
    })
def _make_completion_request(correlation_id, goal, answer):
    """Build a completion request as emit_subagent_completion would."""
    completion = AgentStep(
        thought="Subagent completed",
        action="complete",
        arguments={},
        observation=answer,
        step_type="subagent-completion",
    )
    return _make_request(
        correlation_id=correlation_id,
        parent_session_id="parent-sess",
        subagent_goal=goal,
        expected_siblings=2,
        history=[completion],
    )
class TestCompletionDetection:
    """Test that completion messages are correctly identified."""

    @staticmethod
    def _has_correlation(req):
        # Mirrors the service's correlation-id presence check.
        return bool(getattr(req, 'correlation_id', ''))

    @staticmethod
    def _has_completion_step(req):
        # Mirrors the service's scan of history for a completion step.
        return any(
            getattr(h, 'step_type', '') == 'subagent-completion'
            for h in req.history
        )

    def test_is_completion_when_correlation_id_and_step_type(self):
        req = _make_completion_request("corr-1", "goal-a", "answer-a")
        assert self._has_correlation(req)
        assert self._has_completion_step(req)

    def test_not_completion_without_correlation_id(self):
        req = _make_request(
            correlation_id="",
            history=[AgentStep(step_type="subagent-completion",
                               observation="answer")],
        )
        assert not self._has_correlation(req)

    def test_not_completion_without_step_type(self):
        req = _make_request(
            correlation_id="corr-1",
            history=[AgentStep(step_type="react", observation="answer")],
        )
        assert not self._has_completion_step(req)

    def test_not_completion_with_empty_history(self):
        req = _make_request(
            correlation_id="corr-1",
            history=[],
        )
        assert not req.history
class TestAggregatorIntegration:
    """Test the aggregator flow as used by _handle_subagent_completion."""

    def test_full_completion_flow(self):
        """Simulates the flow: register, record completions, build synthesis."""
        aggregator = Aggregator()
        template = _make_request(
            question="Original question",
            streaming=True,
            task_type="risk-assessment",
            framing="Assess risks",
            session_id="parent-sess",
        )
        # Register fan-out of two subagents.
        aggregator.register_fanout("corr-1", "parent-sess", 2,
                                   request_template=template)
        # Only the final completion reports "all done".
        assert aggregator.record_completion(
            "corr-1", "goal-a", "answer-a") is False
        assert aggregator.record_completion(
            "corr-1", "goal-b", "answer-b") is True
        # Peeking must not consume the entry.
        peeked = aggregator.get_original_request("corr-1")
        assert peeked.question == "Original question"
        # Build and verify the synthesis request.
        synth = aggregator.build_synthesis_request(
            "corr-1",
            original_question="Original question",
            user="testuser",
            collection="default",
        )
        assert synth.pattern == "supervisor"
        assert synth.correlation_id == ""
        assert synth.session_id == "parent-sess"
        assert synth.streaming is True
        # The history must contain exactly one synthesise step carrying
        # all the collected results.
        synthesis_steps = [
            step for step in synth.history
            if getattr(step, 'step_type', '') == 'synthesise'
        ]
        assert len(synthesis_steps) == 1
        assert synthesis_steps[0].subagent_results == {
            "goal-a": "answer-a",
            "goal-b": "answer-b",
        }

    def test_synthesis_request_not_detected_as_completion(self):
        """The synthesis request must not be intercepted as a completion."""
        aggregator = Aggregator()
        aggregator.register_fanout(
            "corr-1", "parent-sess", 1,
            request_template=_make_request(session_id="parent-sess"),
        )
        aggregator.record_completion("corr-1", "goal", "answer")
        synth = aggregator.build_synthesis_request(
            "corr-1", "question", "user", "default",
        )
        # correlation_id must be empty so it's not intercepted
        assert synth.correlation_id == ""
        # Even if we check for completion step, shouldn't match
        assert not any(
            getattr(step, 'step_type', '') == 'subagent-completion'
            for step in synth.history
        )

View file

@ -0,0 +1,162 @@
"""
Unit tests for explainability API parsing. These verify that from_triples()
correctly dispatches and parses the new orchestrator entity types.
"""
import pytest
from trustgraph.api.explainability import (
ExplainEntity,
Decomposition,
Finding,
Plan,
StepResult,
Synthesis,
Analysis,
Conclusion,
TG_DECOMPOSITION,
TG_FINDING,
TG_PLAN_TYPE,
TG_STEP_RESULT,
TG_SYNTHESIS,
TG_ANSWER_TYPE,
TG_ANALYSIS,
TG_CONCLUSION,
TG_DOCUMENT,
TG_SUBAGENT_GOAL,
TG_PLAN_STEP,
RDF_TYPE,
)
# prov:Entity class IRI — every provenance node in these tests is typed
# with it in addition to its specific TG_* type.
PROV_ENTITY = "http://www.w3.org/ns/prov#Entity"


def _make_triples(uri, types, extras=None):
    """Build a list of (s, p, o) tuples for testing."""
    result = [(uri, RDF_TYPE, t) for t in types]
    for predicate, obj in extras or []:
        result.append((uri, predicate, obj))
    return result
class TestFromTriplesDispatch:
    def _dispatch(self, uri, types):
        # Type the entity and run it through the from_triples dispatcher.
        return ExplainEntity.from_triples(uri, _make_triples(uri, types))

    def test_dispatches_decomposition(self):
        assert isinstance(
            self._dispatch("urn:d", [PROV_ENTITY, TG_DECOMPOSITION]),
            Decomposition)

    def test_dispatches_finding(self):
        assert isinstance(
            self._dispatch("urn:f",
                           [PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE]),
            Finding)

    def test_dispatches_plan(self):
        assert isinstance(
            self._dispatch("urn:p", [PROV_ENTITY, TG_PLAN_TYPE]),
            Plan)

    def test_dispatches_step_result(self):
        assert isinstance(
            self._dispatch("urn:sr",
                           [PROV_ENTITY, TG_STEP_RESULT, TG_ANSWER_TYPE]),
            StepResult)

    def test_dispatches_synthesis(self):
        assert isinstance(
            self._dispatch("urn:s",
                           [PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE]),
            Synthesis)

    def test_dispatches_analysis_unchanged(self):
        assert isinstance(
            self._dispatch("urn:a", [PROV_ENTITY, TG_ANALYSIS]),
            Analysis)

    def test_dispatches_conclusion_unchanged(self):
        assert isinstance(
            self._dispatch("urn:c",
                           [PROV_ENTITY, TG_CONCLUSION, TG_ANSWER_TYPE]),
            Conclusion)

    def test_finding_takes_precedence_over_synthesis(self):
        """Finding has Answer mixin but should dispatch to Finding, not
        Synthesis, because Finding is checked first."""
        parsed = self._dispatch(
            "urn:f", [PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE])
        assert isinstance(parsed, Finding)
        assert not isinstance(parsed, Synthesis)
class TestDecompositionParsing:
    def test_parses_goals(self):
        parsed = Decomposition.from_triples("urn:d", _make_triples(
            "urn:d", [TG_DECOMPOSITION],
            [(TG_SUBAGENT_GOAL, "What is X?"),
             (TG_SUBAGENT_GOAL, "What is Y?")],
        ))
        assert set(parsed.goals) == {"What is X?", "What is Y?"}

    def test_entity_type_field(self):
        parsed = Decomposition.from_triples(
            "urn:d", _make_triples("urn:d", [TG_DECOMPOSITION]))
        assert parsed.entity_type == "decomposition"

    def test_empty_goals(self):
        # No goal triples at all: goals parses to an empty list.
        parsed = Decomposition.from_triples(
            "urn:d", _make_triples("urn:d", [TG_DECOMPOSITION]))
        assert parsed.goals == []
class TestFindingParsing:
    def test_parses_goal_and_document(self):
        parsed = Finding.from_triples("urn:f", _make_triples(
            "urn:f", [TG_FINDING, TG_ANSWER_TYPE],
            [(TG_SUBAGENT_GOAL, "What is X?"),
             (TG_DOCUMENT, "urn:doc/finding")],
        ))
        assert parsed.goal == "What is X?"
        assert parsed.document == "urn:doc/finding"

    def test_entity_type_field(self):
        parsed = Finding.from_triples(
            "urn:f", _make_triples("urn:f", [TG_FINDING]))
        assert parsed.entity_type == "finding"
class TestPlanParsing:
    def test_parses_steps(self):
        parsed = Plan.from_triples("urn:p", _make_triples(
            "urn:p", [TG_PLAN_TYPE],
            [(TG_PLAN_STEP, "Define X"),
             (TG_PLAN_STEP, "Research Y"),
             (TG_PLAN_STEP, "Analyse Z")],
        ))
        assert set(parsed.steps) == {"Define X", "Research Y", "Analyse Z"}

    def test_entity_type_field(self):
        parsed = Plan.from_triples(
            "urn:p", _make_triples("urn:p", [TG_PLAN_TYPE]))
        assert parsed.entity_type == "plan"
class TestStepResultParsing:
    def test_parses_step_and_document(self):
        parsed = StepResult.from_triples("urn:sr", _make_triples(
            "urn:sr", [TG_STEP_RESULT, TG_ANSWER_TYPE],
            [(TG_PLAN_STEP, "Define X"),
             (TG_DOCUMENT, "urn:doc/step")],
        ))
        assert parsed.step == "Define X"
        assert parsed.document == "urn:doc/step"

    def test_entity_type_field(self):
        parsed = StepResult.from_triples(
            "urn:sr", _make_triples("urn:sr", [TG_STEP_RESULT]))
        assert parsed.entity_type == "step-result"

View file

@ -0,0 +1,289 @@
"""
Unit tests for the MetaRouter task type identification and pattern selection.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.meta_router import (
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
)
def _make_config(patterns=None, task_types=None):
"""Build a config dict as the config service would provide."""
config = {}
if patterns:
config["agent-pattern"] = {
pid: json.dumps(pdata) for pid, pdata in patterns.items()
}
if task_types:
config["agent-task-type"] = {
tid: json.dumps(tdata) for tid, tdata in task_types.items()
}
return config
def _make_context(prompt_response):
"""Build a mock context that returns a mock prompt client."""
client = AsyncMock()
client.prompt = AsyncMock(return_value=prompt_response)
def context(service_name):
return client
return context
# Pattern registry fixture: pattern id -> metadata, as stored under the
# "agent-pattern" config key.
SAMPLE_PATTERNS = {
    "react": {"name": "react", "description": "ReAct pattern"},
    "plan-then-execute": {"name": "plan-then-execute", "description": "Plan pattern"},
    "supervisor": {"name": "supervisor", "description": "Supervisor pattern"},
}
# Task-type fixture: each entry constrains pattern selection via
# "valid_patterns" and supplies "framing" text for the prompt.
SAMPLE_TASK_TYPES = {
    "general": {
        "name": "general",
        "description": "General queries",
        "valid_patterns": ["react", "plan-then-execute", "supervisor"],
        "framing": "",
    },
    "research": {
        "name": "research",
        "description": "Research queries",
        "valid_patterns": ["react", "plan-then-execute"],
        "framing": "Focus on gathering information.",
    },
    "summarisation": {
        "name": "summarisation",
        "description": "Summarisation queries",
        # Single valid pattern — exercises the "skip the LLM" paths below.
        "valid_patterns": ["react"],
        "framing": "Focus on concise synthesis.",
    },
}
class TestMetaRouterInit:
    def test_defaults_when_no_config(self):
        router = MetaRouter()
        assert "react" in router.patterns
        assert "general" in router.task_types

    def test_loads_patterns_from_config(self):
        router = MetaRouter(config=_make_config(patterns=SAMPLE_PATTERNS))
        assert set(router.patterns) == {
            "react", "plan-then-execute", "supervisor"}

    def test_loads_task_types_from_config(self):
        router = MetaRouter(config=_make_config(task_types=SAMPLE_TASK_TYPES))
        assert set(router.task_types) == {
            "general", "research", "summarisation"}

    def test_handles_invalid_json_in_config(self):
        router = MetaRouter(config={
            "agent-pattern": {"react": "not valid json"},
        })
        # The unparseable entry should leave a usable "react" pattern in
        # place (presumably the built-in default — the router must not
        # crash or lose the pattern).
        assert "react" in router.patterns
        assert router.patterns["react"]["name"] == "react"
class TestIdentifyTaskType:
    def _router(self):
        # Router loaded with the full sample configuration.
        return MetaRouter(config=_make_config(
            patterns=SAMPLE_PATTERNS, task_types=SAMPLE_TASK_TYPES))

    @pytest.mark.asyncio
    async def test_skips_llm_when_single_task_type(self):
        # The default router only knows "general", so no LLM call is made.
        task_type, framing = await MetaRouter().identify_task_type(
            "test question", _make_context("should not be called"))
        assert task_type == "general"

    @pytest.mark.asyncio
    async def test_uses_llm_when_multiple_task_types(self):
        task_type, framing = await self._router().identify_task_type(
            "Research the topic", _make_context("research"))
        assert task_type == "research"
        assert framing == "Focus on gathering information."

    @pytest.mark.asyncio
    async def test_handles_llm_returning_quoted_type(self):
        task_type, _ = await self._router().identify_task_type(
            "Summarise this", _make_context('"summarisation"'))
        assert task_type == "summarisation"

    @pytest.mark.asyncio
    async def test_falls_back_on_unknown_type(self):
        task_type, _ = await self._router().identify_task_type(
            "test question", _make_context("nonexistent-type"))
        assert task_type == DEFAULT_TASK_TYPE

    @pytest.mark.asyncio
    async def test_falls_back_on_llm_error(self):
        failing = AsyncMock()
        failing.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
        task_type, _ = await self._router().identify_task_type(
            "test question", lambda name: failing)
        assert task_type == DEFAULT_TASK_TYPE
class TestSelectPattern:
    def _router(self):
        # Router loaded with the full sample configuration.
        return MetaRouter(config=_make_config(
            patterns=SAMPLE_PATTERNS, task_types=SAMPLE_TASK_TYPES))

    @pytest.mark.asyncio
    async def test_skips_llm_when_single_valid_pattern(self):
        # summarisation only has ["react"]
        chosen = await self._router().select_pattern(
            "Summarise this", "summarisation",
            _make_context("should not be called"))
        assert chosen == "react"

    @pytest.mark.asyncio
    async def test_uses_llm_when_multiple_valid_patterns(self):
        # research has ["react", "plan-then-execute"]
        chosen = await self._router().select_pattern(
            "Research this", "research", _make_context("plan-then-execute"))
        assert chosen == "plan-then-execute"

    @pytest.mark.asyncio
    async def test_respects_valid_patterns_constraint(self):
        # LLM returns supervisor, but research doesn't allow it
        chosen = await self._router().select_pattern(
            "Research this", "research", _make_context("supervisor"))
        # Should fall back to first valid pattern
        assert chosen == "react"

    @pytest.mark.asyncio
    async def test_falls_back_on_llm_error(self):
        failing = AsyncMock()
        failing.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
        # general has ["react", "plan-then-execute", "supervisor"]
        chosen = await self._router().select_pattern(
            "test", "general", lambda name: failing)
        # Falls back to first valid pattern
        assert chosen == "react"

    @pytest.mark.asyncio
    async def test_falls_back_to_default_for_unknown_task_type(self):
        # Unknown task type — valid_patterns falls back to all patterns
        chosen = await self._router().select_pattern(
            "test", "unknown-type", _make_context("react"))
        assert chosen == "react"
class TestRoute:
    @pytest.mark.asyncio
    async def test_full_routing_pipeline(self):
        router = MetaRouter(config=_make_config(
            patterns=SAMPLE_PATTERNS, task_types=SAMPLE_TASK_TYPES))
        # The first prompt call identifies the task type; every later
        # call selects the pattern.
        responses = iter(["research"])

        async def scripted_prompt(**kwargs):
            return next(responses, "plan-then-execute")

        client = AsyncMock()
        client.prompt = scripted_prompt
        pattern, task_type, framing = await router.route(
            "Research the relationships", lambda name: client,
        )
        assert task_type == "research"
        assert pattern == "plan-then-execute"
        assert framing == "Focus on gathering information."

View file

@ -0,0 +1,144 @@
"""
Unit tests for the PatternBase subagent helpers: is_subagent() and
emit_subagent_completion().
"""
import pytest
from unittest.mock import MagicMock, AsyncMock
from dataclasses import dataclass
from trustgraph.schema import AgentRequest
from trustgraph.agent.orchestrator.pattern_base import PatternBase
@dataclass
class MockProcessor:
    """Minimal processor mock for PatternBase."""
    # No attributes or behaviour — PatternBase is only handed the object
    # here; these tests never exercise processor functionality.
    pass
def _make_request(**kwargs):
    """Build an AgentRequest, overriding the test defaults with kwargs."""
    params = {
        "question": "Test question",
        "user": "testuser",
        "collection": "default",
    }
    params.update(kwargs)
    return AgentRequest(**params)
def _make_pattern():
    """Build a PatternBase wired to a do-nothing processor."""
    return PatternBase(MockProcessor())
class TestIsSubagent:
    """is_subagent() is driven purely by the correlation_id field."""

    def test_returns_true_when_correlation_id_set(self):
        request = _make_request(correlation_id="corr-123")
        assert _make_pattern().is_subagent(request) is True

    def test_returns_false_when_correlation_id_empty(self):
        request = _make_request(correlation_id="")
        assert _make_pattern().is_subagent(request) is False

    def test_returns_false_when_correlation_id_missing(self):
        assert _make_pattern().is_subagent(_make_request()) is False
class TestEmitSubagentCompletion:
    """Tests for the shape of the message emit_subagent_completion()
    forwards downstream."""

    async def _emit(self, request, answer):
        # Drive emit_subagent_completion against a mock next_fn and
        # return (mock, forwarded request).
        next_fn = AsyncMock()
        await _make_pattern().emit_subagent_completion(
            request, next_fn, answer)
        return next_fn, next_fn.call_args[0][0]

    @pytest.mark.asyncio
    async def test_calls_next_with_completion_request(self):
        request = _make_request(
            correlation_id="corr-123",
            parent_session_id="parent-sess",
            subagent_goal="What is X?",
            expected_siblings=4,
        )
        next_fn, forwarded = await self._emit(request, "The answer is Y")
        next_fn.assert_called_once()
        assert isinstance(forwarded, AgentRequest)

    @pytest.mark.asyncio
    async def test_completion_has_correct_step_type(self):
        request = _make_request(
            correlation_id="corr-123",
            subagent_goal="What is X?",
        )
        _, forwarded = await self._emit(request, "answer text")
        assert len(forwarded.history) == 1
        assert forwarded.history[0].step_type == "subagent-completion"

    @pytest.mark.asyncio
    async def test_completion_carries_answer_in_observation(self):
        request = _make_request(
            correlation_id="corr-123",
            subagent_goal="What is X?",
        )
        _, forwarded = await self._emit(request, "The answer is Y")
        assert forwarded.history[0].observation == "The answer is Y"

    @pytest.mark.asyncio
    async def test_completion_preserves_correlation_fields(self):
        request = _make_request(
            correlation_id="corr-123",
            parent_session_id="parent-sess",
            subagent_goal="What is X?",
            expected_siblings=4,
        )
        _, forwarded = await self._emit(request, "answer")
        assert forwarded.correlation_id == "corr-123"
        assert forwarded.parent_session_id == "parent-sess"
        assert forwarded.subagent_goal == "What is X?"
        assert forwarded.expected_siblings == 4

    @pytest.mark.asyncio
    async def test_completion_has_empty_pattern(self):
        request = _make_request(
            correlation_id="corr-123",
            subagent_goal="goal",
        )
        _, forwarded = await self._emit(request, "answer")
        assert forwarded.pattern == ""

View file

@ -0,0 +1,226 @@
"""
Unit tests for orchestrator provenance triple builders.
"""
import pytest
from trustgraph.provenance import (
agent_decomposition_triples,
agent_finding_triples,
agent_plan_triples,
agent_step_result_triples,
agent_synthesis_triples,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SYNTHESIS, TG_ANSWER_TYPE, TG_DOCUMENT,
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
)
def _triple_set(triples):
"""Convert triples to a set of (s_iri, p_iri, o_value) for easy assertion."""
result = set()
for t in triples:
s = t.s.iri
p = t.p.iri
o = t.o.iri if t.o.iri else t.o.value
result.add((s, p, o))
return result
def _has_type(triples, uri, rdf_type):
    """Return True if *uri* carries rdf:type *rdf_type* within *triples*."""
    wanted = (uri, RDF_TYPE, rdf_type)
    return any(entry == wanted for entry in _triple_set(triples))
def _get_values(triples, uri, predicate):
    """Collect every object value attached to (*uri*, *predicate*)."""
    return [
        obj
        for subj, pred, obj in _triple_set(triples)
        if subj == uri and pred == predicate
    ]
class TestDecompositionTriples:
    """agent_decomposition_triples(): typing, session link, goals, label."""

    def test_has_correct_types(self):
        out = agent_decomposition_triples(
            "urn:decompose", "urn:session", ["goal-a", "goal-b"])
        assert _has_type(out, "urn:decompose", PROV_ENTITY)
        assert _has_type(out, "urn:decompose", TG_DECOMPOSITION)

    def test_not_answer_type(self):
        out = agent_decomposition_triples(
            "urn:decompose", "urn:session", ["goal-a"])
        assert not _has_type(out, "urn:decompose", TG_ANSWER_TYPE)

    def test_links_to_session(self):
        out = agent_decomposition_triples(
            "urn:decompose", "urn:session", ["goal-a"])
        assert ("urn:decompose", PROV_WAS_GENERATED_BY,
                "urn:session") in _triple_set(out)

    def test_includes_goals(self):
        goals = ["What is X?", "What is Y?", "What is Z?"]
        out = agent_decomposition_triples(
            "urn:decompose", "urn:session", goals)
        found = _get_values(out, "urn:decompose", TG_SUBAGENT_GOAL)
        assert set(found) == set(goals)

    def test_label_includes_count(self):
        out = agent_decomposition_triples(
            "urn:decompose", "urn:session", ["a", "b", "c"])
        labels = _get_values(out, "urn:decompose", RDFS_LABEL)
        assert any("3" in text for text in labels)
class TestFindingTriples:
    """agent_finding_triples(): typing, decomposition link, goal, document."""

    def test_has_correct_types(self):
        out = agent_finding_triples(
            "urn:finding", "urn:decompose", "What is X?")
        assert _has_type(out, "urn:finding", PROV_ENTITY)
        assert _has_type(out, "urn:finding", TG_FINDING)
        assert _has_type(out, "urn:finding", TG_ANSWER_TYPE)

    def test_links_to_decomposition(self):
        out = agent_finding_triples(
            "urn:finding", "urn:decompose", "What is X?")
        assert ("urn:finding", PROV_WAS_DERIVED_FROM,
                "urn:decompose") in _triple_set(out)

    def test_includes_goal(self):
        out = agent_finding_triples(
            "urn:finding", "urn:decompose", "What is X?")
        assert "What is X?" in _get_values(
            out, "urn:finding", TG_SUBAGENT_GOAL)

    def test_includes_document_when_provided(self):
        out = agent_finding_triples(
            "urn:finding", "urn:decompose", "goal",
            document_id="urn:doc/1")
        assert "urn:doc/1" in _get_values(out, "urn:finding", TG_DOCUMENT)

    def test_no_document_when_none(self):
        out = agent_finding_triples(
            "urn:finding", "urn:decompose", "goal")
        assert _get_values(out, "urn:finding", TG_DOCUMENT) == []
class TestPlanTriples:
    """agent_plan_triples(): typing, session link, steps, label."""

    def test_has_correct_types(self):
        out = agent_plan_triples("urn:plan", "urn:session", ["step-a"])
        assert _has_type(out, "urn:plan", PROV_ENTITY)
        assert _has_type(out, "urn:plan", TG_PLAN_TYPE)

    def test_not_answer_type(self):
        out = agent_plan_triples("urn:plan", "urn:session", ["step-a"])
        assert not _has_type(out, "urn:plan", TG_ANSWER_TYPE)

    def test_links_to_session(self):
        out = agent_plan_triples("urn:plan", "urn:session", ["step-a"])
        assert ("urn:plan", PROV_WAS_GENERATED_BY,
                "urn:session") in _triple_set(out)

    def test_includes_steps(self):
        steps = ["Define X", "Research Y", "Analyse Z"]
        out = agent_plan_triples("urn:plan", "urn:session", steps)
        found = _get_values(out, "urn:plan", TG_PLAN_STEP)
        assert set(found) == set(steps)

    def test_label_includes_count(self):
        out = agent_plan_triples("urn:plan", "urn:session", ["a", "b"])
        labels = _get_values(out, "urn:plan", RDFS_LABEL)
        assert any("2" in text for text in labels)
class TestStepResultTriples:
    """agent_step_result_triples(): typing, plan link, goal, document."""

    def test_has_correct_types(self):
        out = agent_step_result_triples("urn:step", "urn:plan", "Define X")
        assert _has_type(out, "urn:step", PROV_ENTITY)
        assert _has_type(out, "urn:step", TG_STEP_RESULT)
        assert _has_type(out, "urn:step", TG_ANSWER_TYPE)

    def test_links_to_plan(self):
        out = agent_step_result_triples("urn:step", "urn:plan", "Define X")
        assert ("urn:step", PROV_WAS_DERIVED_FROM,
                "urn:plan") in _triple_set(out)

    def test_includes_goal(self):
        out = agent_step_result_triples("urn:step", "urn:plan", "Define X")
        assert "Define X" in _get_values(out, "urn:step", TG_PLAN_STEP)

    def test_includes_document_when_provided(self):
        out = agent_step_result_triples(
            "urn:step", "urn:plan", "goal",
            document_id="urn:doc/step")
        assert "urn:doc/step" in _get_values(out, "urn:step", TG_DOCUMENT)
class TestSynthesisTriples:
    """agent_synthesis_triples(): typing, derivation link, document, label."""

    def test_has_correct_types(self):
        out = agent_synthesis_triples("urn:synthesis", "urn:previous")
        assert _has_type(out, "urn:synthesis", PROV_ENTITY)
        assert _has_type(out, "urn:synthesis", TG_SYNTHESIS)
        assert _has_type(out, "urn:synthesis", TG_ANSWER_TYPE)

    def test_links_to_previous(self):
        out = agent_synthesis_triples("urn:synthesis", "urn:last-finding")
        assert ("urn:synthesis", PROV_WAS_DERIVED_FROM,
                "urn:last-finding") in _triple_set(out)

    def test_includes_document_when_provided(self):
        out = agent_synthesis_triples(
            "urn:synthesis", "urn:previous",
            document_id="urn:doc/synthesis")
        assert "urn:doc/synthesis" in _get_values(
            out, "urn:synthesis", TG_DOCUMENT)

    def test_label_is_synthesis(self):
        out = agent_synthesis_triples("urn:synthesis", "urn:previous")
        assert "Synthesis" in _get_values(
            out, "urn:synthesis", RDFS_LABEL)