mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Update tests for agent-orchestrator (#745)
Add 96 tests covering the orchestrator's aggregation, provenance, routing, and explainability parsing. These verify the supervisor fan-out/fan-in lifecycle, the new RDF provenance types (Decomposition, Finding, Plan, StepResult, Synthesis), and their round-trip through the wire format.

Unit tests (84):
- Aggregator: register, record completion, peek, build synthesis, cleanup
- Provenance triple builders: types, provenance links, goals/steps, labels
- Explainability parsing: from_triples dispatch, field extraction for all new entity types, precedence over existing types
- PatternBase: is_subagent detection, emit_subagent_completion message shape
- Completion dispatch: detection logic, full aggregator integration flow, synthesis request not re-intercepted as completion
- MetaRouter: task type identification, pattern selection, valid_patterns constraints, fallback on LLM error or unknown response

Contract tests (12):
- Orchestration fields on AgentRequest round-trip correctly
- subagent-completion and synthesise step types in request history
- Plan steps with status and dependencies
- Provenance triple builder → wire format → from_triples round-trip for all five new entity types
This commit is contained in:
parent
7b734148b3
commit
816a8cfcf6
8 changed files with 1517 additions and 0 deletions
177
tests/contract/test_orchestrator_contracts.py
Normal file
177
tests/contract/test_orchestrator_contracts.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""
|
||||
Contract tests for orchestrator message schemas.
|
||||
|
||||
Verifies that AgentRequest/AgentStep with orchestration fields
|
||||
serialise and deserialise correctly through the Pulsar schema layer.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
|
||||
from trustgraph.schema import AgentRequest, AgentStep, PlanStep
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestOrchestrationFieldContracts:
    """Contract tests for orchestration fields on AgentRequest."""

    def test_agent_request_orchestration_fields_roundtrip(self):
        """Orchestration fields passed to the constructor read back unchanged."""
        req = AgentRequest(
            question="Test question",
            user="testuser",
            collection="default",
            correlation_id="corr-123",
            parent_session_id="parent-sess",
            subagent_goal="What is X?",
            expected_siblings=4,
            pattern="react",
            task_type="research",
            framing="Focus on accuracy",
            conversation_id="conv-456",
        )

        assert req.correlation_id == "corr-123"
        assert req.parent_session_id == "parent-sess"
        assert req.subagent_goal == "What is X?"
        assert req.expected_siblings == 4
        assert req.pattern == "react"
        assert req.task_type == "research"
        assert req.framing == "Focus on accuracy"
        assert req.conversation_id == "conv-456"

    def test_agent_request_orchestration_fields_default_empty(self):
        """Omitted orchestration fields default to empty string / zero."""
        req = AgentRequest(
            question="Test question",
            user="testuser",
        )

        assert req.correlation_id == ""
        assert req.parent_session_id == ""
        assert req.subagent_goal == ""
        assert req.expected_siblings == 0
        assert req.pattern == ""
        assert req.task_type == ""
        assert req.framing == ""
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestSubagentCompletionStepContract:
    """Contract tests for subagent-completion step type."""

    def test_subagent_completion_step_fields(self):
        """A subagent-completion AgentStep keeps the fields it was built with."""
        step = AgentStep(
            thought="Subagent completed",
            action="complete",
            arguments={},
            observation="The answer text",
            step_type="subagent-completion",
        )

        assert step.step_type == "subagent-completion"
        assert step.observation == "The answer text"
        assert step.thought == "Subagent completed"
        assert step.action == "complete"

    def test_subagent_completion_in_request_history(self):
        """A completion step placed in AgentRequest.history is readable from it."""
        step = AgentStep(
            thought="Subagent completed",
            action="complete",
            arguments={},
            observation="answer",
            step_type="subagent-completion",
        )
        req = AgentRequest(
            question="goal",
            user="testuser",
            correlation_id="corr-123",
            history=[step],
        )

        assert len(req.history) == 1
        assert req.history[0].step_type == "subagent-completion"
        assert req.history[0].observation == "answer"
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestSynthesisStepContract:
    """Contract tests for synthesis step type with subagent_results."""

    def test_synthesis_step_with_results(self):
        """A synthesise step carries results both as a dict and as JSON text."""
        results = {"goal-a": "answer-a", "goal-b": "answer-b"}
        step = AgentStep(
            thought="All subagents completed",
            action="aggregate",
            arguments={},
            observation=json.dumps(results),
            step_type="synthesise",
            subagent_results=results,
        )

        assert step.step_type == "synthesise"
        assert step.subagent_results == results
        assert json.loads(step.observation) == results

    def test_synthesis_request_matches_supervisor_expectations(self):
        """The synthesis request built by the aggregator must be
        recognisable by SupervisorPattern._synthesise()."""
        results = {"goal-a": "answer-a", "goal-b": "answer-b"}
        step = AgentStep(
            thought="All subagents completed",
            action="aggregate",
            arguments={},
            observation=json.dumps(results),
            step_type="synthesise",
            subagent_results=results,
        )

        req = AgentRequest(
            question="Original question",
            user="testuser",
            pattern="supervisor",
            correlation_id="",
            session_id="parent-sess",
            history=[step],
        )

        # SupervisorPattern checks for step_type='synthesise' with
        # subagent_results
        has_results = bool(
            req.history
            and any(
                getattr(h, 'step_type', '') == 'synthesise'
                and getattr(h, 'subagent_results', None)
                for h in req.history
            )
        )
        assert has_results

        # Pattern must be supervisor
        assert req.pattern == "supervisor"

        # Correlation ID must be empty (not re-intercepted)
        assert req.correlation_id == ""
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestPlanStepContract:
    """Contract tests for plan steps in history."""

    def test_plan_step_in_history(self):
        """An AgentStep of type 'plan' exposes its PlanStep list unchanged."""
        plan = [
            PlanStep(goal="Step 1", tool_hint="knowledge-query",
                     depends_on=[], status="completed", result="done"),
            PlanStep(goal="Step 2", tool_hint="",
                     depends_on=[0], status="pending", result=""),
        ]
        step = AgentStep(
            thought="Created plan",
            action="plan",
            step_type="plan",
            plan=plan,
        )

        assert step.step_type == "plan"
        assert len(step.plan) == 2
        assert step.plan[0].goal == "Step 1"
        assert step.plan[0].status == "completed"
        assert step.plan[1].depends_on == [0]
|
||||
129
tests/contract/test_provenance_wire_format.py
Normal file
129
tests/contract/test_provenance_wire_format.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""
|
||||
Contract tests for provenance triple wire format — verifies that triples
|
||||
built by the provenance library can be parsed by the explainability API
|
||||
through the wire format conversion.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
|
||||
from trustgraph.provenance import (
|
||||
agent_decomposition_triples,
|
||||
agent_finding_triples,
|
||||
agent_plan_triples,
|
||||
agent_step_result_triples,
|
||||
agent_synthesis_triples,
|
||||
)
|
||||
|
||||
from trustgraph.api.explainability import (
|
||||
ExplainEntity,
|
||||
Decomposition,
|
||||
Finding,
|
||||
Plan,
|
||||
StepResult,
|
||||
Synthesis,
|
||||
wire_triples_to_tuples,
|
||||
)
|
||||
|
||||
|
||||
def _triples_to_wire(triples):
    """Convert provenance Triple objects to the wire format dicts
    that the gateway/socket client would produce.

    Each triple becomes a dict with keys ``"s"``, ``"p"`` and ``"o"``,
    each holding the wire form of the corresponding term.  An empty
    input yields an empty list.
    """
    return [
        {
            "s": _term_to_wire(t.s),
            "p": _term_to_wire(t.p),
            "o": _term_to_wire(t.o),
        }
        for t in triples
    ]
|
||||
|
||||
|
||||
def _term_to_wire(term):
    """Convert a Term to wire format dict.

    IRI terms map to ``{"t": "i", "i": term.iri}`` and literal terms to
    ``{"t": "l", "v": term.value}``.  A term of any other type falls back
    to a literal whose value is ``str(term)``.
    """
    if term.type == IRI:
        return {"t": "i", "i": term.iri}
    if term.type == LITERAL:
        return {"t": "l", "v": term.value}
    return {"t": "l", "v": str(term)}
|
||||
|
||||
|
||||
def _roundtrip(triples, uri):
    """Convert triples through wire format and parse via from_triples.

    The triples are serialised to wire dicts, converted back to tuples
    with ``wire_triples_to_tuples``, and handed to
    ``ExplainEntity.from_triples`` for the given ``uri``.
    """
    wire = _triples_to_wire(triples)
    tuples = wire_triples_to_tuples(wire)
    return ExplainEntity.from_triples(uri, tuples)
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestDecompositionWireFormat:
    """Wire-format round-trip for decomposition provenance triples."""

    def test_roundtrip(self):
        """Decomposition triples parse back to a Decomposition with both goals."""
        triples = agent_decomposition_triples(
            "urn:decompose", "urn:session",
            ["What is X?", "What is Y?"],
        )
        entity = _roundtrip(triples, "urn:decompose")

        assert isinstance(entity, Decomposition)
        # Compared as a set: the test does not depend on goal ordering.
        assert set(entity.goals) == {"What is X?", "What is Y?"}
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestFindingWireFormat:
    """Wire-format round-trip for finding provenance triples."""

    def test_roundtrip(self):
        """Finding triples parse back to a Finding with goal and document."""
        triples = agent_finding_triples(
            "urn:finding", "urn:decompose", "What is X?",
            document_id="urn:doc/finding",
        )
        entity = _roundtrip(triples, "urn:finding")

        assert isinstance(entity, Finding)
        assert entity.goal == "What is X?"
        assert entity.document == "urn:doc/finding"
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestPlanWireFormat:
    """Wire-format round-trip for plan provenance triples."""

    def test_roundtrip(self):
        """Plan triples parse back to a Plan containing all three steps."""
        triples = agent_plan_triples(
            "urn:plan", "urn:session",
            ["Step 1", "Step 2", "Step 3"],
        )
        entity = _roundtrip(triples, "urn:plan")

        assert isinstance(entity, Plan)
        # Compared as a set: the test does not depend on step ordering.
        assert set(entity.steps) == {"Step 1", "Step 2", "Step 3"}
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestStepResultWireFormat:
    """Wire-format round-trip for step-result provenance triples."""

    def test_roundtrip(self):
        """Step-result triples parse back to a StepResult with step and document."""
        triples = agent_step_result_triples(
            "urn:step", "urn:plan", "Define X",
            document_id="urn:doc/step",
        )
        entity = _roundtrip(triples, "urn:step")

        assert isinstance(entity, StepResult)
        assert entity.step == "Define X"
        assert entity.document == "urn:doc/step"
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestSynthesisWireFormat:
    """Wire-format round-trip for synthesis provenance triples."""

    def test_roundtrip(self):
        """Synthesis triples parse back to a Synthesis with its document."""
        triples = agent_synthesis_triples(
            "urn:synthesis", "urn:previous",
            document_id="urn:doc/synthesis",
        )
        entity = _roundtrip(triples, "urn:synthesis")

        assert isinstance(entity, Synthesis)
        assert entity.document == "urn:doc/synthesis"
|
||||
216
tests/unit/test_agent/test_aggregator.py
Normal file
216
tests/unit/test_agent/test_aggregator.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
"""
|
||||
Unit tests for the Aggregator — tracks fan-out correlations and triggers
|
||||
synthesis when all subagents complete.
|
||||
"""
|
||||
|
||||
import time
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import AgentRequest, AgentStep
|
||||
|
||||
from trustgraph.agent.orchestrator.aggregator import Aggregator
|
||||
|
||||
|
||||
def _make_request(question="Test question", user="testuser",
                  collection="default", streaming=False,
                  session_id="parent-session", task_type="research",
                  framing="test framing", conversation_id="conv-1"):
    """Build an AgentRequest for use as a fan-out request template.

    Every parameter is forwarded to the AgentRequest field of the same
    name; the defaults give a ready-made request for tests that do not
    care about the specific values.
    """
    return AgentRequest(
        question=question,
        user=user,
        collection=collection,
        streaming=streaming,
        session_id=session_id,
        task_type=task_type,
        framing=framing,
        conversation_id=conversation_id,
    )
|
||||
|
||||
|
||||
class TestRegisterFanout:
    """Aggregator.register_fanout() stores a correlation entry."""

    def test_stores_correlation_entry(self):
        """The entry records parent session, expected count and empty results."""
        agg = Aggregator()
        agg.register_fanout("corr-1", "parent-1", 3)

        assert "corr-1" in agg.correlations
        entry = agg.correlations["corr-1"]
        assert entry["parent_session_id"] == "parent-1"
        assert entry["expected"] == 3
        assert entry["results"] == {}

    def test_stores_request_template(self):
        """The request template is stored by identity, not copied."""
        agg = Aggregator()
        template = _make_request()
        agg.register_fanout("corr-1", "parent-1", 2,
                            request_template=template)

        entry = agg.correlations["corr-1"]
        assert entry["request_template"] is template

    def test_records_creation_time(self):
        """created_at falls between timestamps taken around the call."""
        agg = Aggregator()
        before = time.time()
        agg.register_fanout("corr-1", "parent-1", 2)
        after = time.time()

        created = agg.correlations["corr-1"]["created_at"]
        assert before <= created <= after
|
||||
|
||||
|
||||
class TestRecordCompletion:
    """Aggregator.record_completion() return values and stored results."""

    def test_returns_false_until_all_done(self):
        """Returns False until the expected number of completions arrive."""
        agg = Aggregator()
        agg.register_fanout("corr-1", "parent-1", 3)

        assert agg.record_completion("corr-1", "goal-a", "answer-a") is False
        assert agg.record_completion("corr-1", "goal-b", "answer-b") is False
        assert agg.record_completion("corr-1", "goal-c", "answer-c") is True

    def test_returns_none_for_unknown_correlation(self):
        """An unregistered correlation ID yields None."""
        agg = Aggregator()
        result = agg.record_completion("unknown", "goal", "answer")
        assert result is None

    def test_stores_results_by_goal(self):
        """Answers are stored in the entry's results dict keyed by goal."""
        agg = Aggregator()
        agg.register_fanout("corr-1", "parent-1", 2)

        agg.record_completion("corr-1", "goal-a", "answer-a")
        agg.record_completion("corr-1", "goal-b", "answer-b")

        results = agg.correlations["corr-1"]["results"]
        assert results["goal-a"] == "answer-a"
        assert results["goal-b"] == "answer-b"

    def test_single_subagent(self):
        """With one expected subagent the first completion returns True."""
        agg = Aggregator()
        agg.register_fanout("corr-1", "parent-1", 1)

        assert agg.record_completion("corr-1", "goal-a", "answer") is True
|
||||
|
||||
|
||||
class TestGetOriginalRequest:
    """Aggregator.get_original_request() peeks at the stored template."""

    def test_peeks_without_consuming(self):
        """Returns the registered template and leaves the entry in place."""
        agg = Aggregator()
        template = _make_request()
        agg.register_fanout("corr-1", "parent-1", 2,
                            request_template=template)

        result = agg.get_original_request("corr-1")
        assert result is template
        # Entry still exists
        assert "corr-1" in agg.correlations

    def test_returns_none_for_unknown(self):
        """An unregistered correlation ID yields None."""
        agg = Aggregator()
        assert agg.get_original_request("unknown") is None
|
||||
|
||||
|
||||
class TestBuildSynthesisRequest:
    """Aggregator.build_synthesis_request() output and side effects."""

    def test_builds_correct_request(self):
        """The synthesis request targets the supervisor pattern and parent
        session, and carries streaming/task_type/framing from the template."""
        agg = Aggregator()
        template = _make_request(
            question="Original question",
            streaming=True,
            task_type="risk-assessment",
            framing="Assess risks",
        )
        agg.register_fanout("corr-1", "parent-1", 2,
                            request_template=template)
        agg.record_completion("corr-1", "goal-a", "answer-a")
        agg.record_completion("corr-1", "goal-b", "answer-b")

        req = agg.build_synthesis_request(
            "corr-1",
            original_question="Original question",
            user="testuser",
            collection="default",
        )

        assert req.question == "Original question"
        assert req.pattern == "supervisor"
        assert req.session_id == "parent-1"
        assert req.correlation_id == ""  # Must be empty
        # Identity check, consistent with the completion-dispatch tests.
        assert req.streaming is True
        assert req.task_type == "risk-assessment"
        assert req.framing == "Assess risks"

    def test_synthesis_step_in_history(self):
        """The last history step is a synthesise step holding all results."""
        agg = Aggregator()
        template = _make_request()
        agg.register_fanout("corr-1", "parent-1", 2,
                            request_template=template)
        agg.record_completion("corr-1", "goal-a", "answer-a")
        agg.record_completion("corr-1", "goal-b", "answer-b")

        req = agg.build_synthesis_request(
            "corr-1", "question", "user", "default",
        )

        # Last history step should be the synthesis step
        assert len(req.history) >= 1
        synth_step = req.history[-1]
        assert synth_step.step_type == "synthesise"
        assert synth_step.subagent_results == {
            "goal-a": "answer-a",
            "goal-b": "answer-b",
        }

    def test_consumes_correlation_entry(self):
        """Building the synthesis request removes the correlation entry."""
        agg = Aggregator()
        template = _make_request()
        agg.register_fanout("corr-1", "parent-1", 1,
                            request_template=template)
        agg.record_completion("corr-1", "goal-a", "answer-a")

        agg.build_synthesis_request(
            "corr-1", "question", "user", "default",
        )

        # Entry should be removed
        assert "corr-1" not in agg.correlations

    def test_raises_for_unknown_correlation(self):
        """An unregistered correlation ID raises RuntimeError ("No results")."""
        agg = Aggregator()
        with pytest.raises(RuntimeError, match="No results"):
            agg.build_synthesis_request(
                "unknown", "question", "user", "default",
            )
|
||||
|
||||
|
||||
class TestCleanupStale:
    """Aggregator.cleanup_stale() removes entries older than the timeout."""

    def test_removes_entries_older_than_timeout(self):
        """An entry backdated past the timeout is returned and removed."""
        agg = Aggregator(timeout=1)
        agg.register_fanout("corr-1", "parent-1", 2)

        # Backdate the creation time
        agg.correlations["corr-1"]["created_at"] = time.time() - 2

        stale = agg.cleanup_stale()
        assert "corr-1" in stale
        assert "corr-1" not in agg.correlations

    def test_keeps_recent_entries(self):
        """A freshly registered entry is kept and nothing is reported stale."""
        agg = Aggregator(timeout=300)
        agg.register_fanout("corr-1", "parent-1", 2)

        stale = agg.cleanup_stale()
        assert stale == []
        assert "corr-1" in agg.correlations

    def test_mixed_stale_and_fresh(self):
        """Only the backdated entry is removed; the fresh one survives."""
        agg = Aggregator(timeout=1)
        agg.register_fanout("stale", "parent-1", 2)
        agg.register_fanout("fresh", "parent-2", 2)

        agg.correlations["stale"]["created_at"] = time.time() - 2

        stale = agg.cleanup_stale()
        assert "stale" in stale
        assert "stale" not in agg.correlations
        assert "fresh" in agg.correlations
|
||||
174
tests/unit/test_agent/test_completion_dispatch.py
Normal file
174
tests/unit/test_agent/test_completion_dispatch.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""
|
||||
Unit tests for completion dispatch — verifies that agent_request() in the
|
||||
orchestrator service correctly intercepts subagent completion messages and
|
||||
routes them to _handle_subagent_completion.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.schema import AgentRequest, AgentStep
|
||||
|
||||
from trustgraph.agent.orchestrator.aggregator import Aggregator
|
||||
|
||||
|
||||
def _make_request(**kwargs):
    """Build an AgentRequest from test defaults overridden by ``kwargs``.

    The defaults supply ``question``, ``user`` and ``collection``; any
    keyword argument replaces the default of the same name or adds a
    further AgentRequest field.
    """
    defaults = dict(
        question="Test question",
        user="testuser",
        collection="default",
    )
    defaults.update(kwargs)
    return AgentRequest(**defaults)
|
||||
|
||||
|
||||
def _make_completion_request(correlation_id, goal, answer):
    """Build a completion request as emit_subagent_completion would.

    The request carries a single ``subagent-completion`` history step
    whose observation is ``answer``, plus the given correlation ID and
    subagent goal.  Parent session ID and sibling count are fixed test
    values.
    """
    step = AgentStep(
        thought="Subagent completed",
        action="complete",
        arguments={},
        observation=answer,
        step_type="subagent-completion",
    )
    return _make_request(
        correlation_id=correlation_id,
        parent_session_id="parent-sess",
        subagent_goal=goal,
        expected_siblings=2,
        history=[step],
    )
|
||||
|
||||
|
||||
class TestCompletionDetection:
    """Test that completion messages are correctly identified."""

    def test_is_completion_when_correlation_id_and_step_type(self):
        """A request with a correlation ID and a completion step matches."""
        req = _make_completion_request("corr-1", "goal-a", "answer-a")

        has_correlation = bool(getattr(req, 'correlation_id', ''))
        is_completion = any(
            getattr(h, 'step_type', '') == 'subagent-completion'
            for h in req.history
        )

        assert has_correlation
        assert is_completion

    def test_not_completion_without_correlation_id(self):
        """An empty correlation ID fails the correlation check."""
        step = AgentStep(
            step_type="subagent-completion",
            observation="answer",
        )
        req = _make_request(
            correlation_id="",
            history=[step],
        )

        has_correlation = bool(getattr(req, 'correlation_id', ''))
        assert not has_correlation

    def test_not_completion_without_step_type(self):
        """A history without a subagent-completion step does not match."""
        step = AgentStep(
            step_type="react",
            observation="answer",
        )
        req = _make_request(
            correlation_id="corr-1",
            history=[step],
        )

        is_completion = any(
            getattr(h, 'step_type', '') == 'subagent-completion'
            for h in req.history
        )
        assert not is_completion

    def test_not_completion_with_empty_history(self):
        """A request built with an empty history has a falsy history."""
        req = _make_request(
            correlation_id="corr-1",
            history=[],
        )
        assert not req.history
|
||||
|
||||
|
||||
class TestAggregatorIntegration:
    """Test the aggregator flow as used by _handle_subagent_completion."""

    def test_full_completion_flow(self):
        """Simulates the flow: register, record completions, build synthesis."""
        agg = Aggregator()
        template = _make_request(
            question="Original question",
            streaming=True,
            task_type="risk-assessment",
            framing="Assess risks",
            session_id="parent-sess",
        )

        # Register fan-out
        agg.register_fanout("corr-1", "parent-sess", 2,
                            request_template=template)

        # First completion — not all done
        all_done = agg.record_completion(
            "corr-1", "goal-a", "answer-a",
        )
        assert all_done is False

        # Second completion — all done
        all_done = agg.record_completion(
            "corr-1", "goal-b", "answer-b",
        )
        assert all_done is True

        # Peek at template
        peeked = agg.get_original_request("corr-1")
        assert peeked.question == "Original question"

        # Build synthesis request
        synth = agg.build_synthesis_request(
            "corr-1",
            original_question="Original question",
            user="testuser",
            collection="default",
        )

        # Verify synthesis request
        assert synth.pattern == "supervisor"
        assert synth.correlation_id == ""
        assert synth.session_id == "parent-sess"
        assert synth.streaming is True

        # Verify synthesis history has results
        synth_steps = [
            s for s in synth.history
            if getattr(s, 'step_type', '') == 'synthesise'
        ]
        assert len(synth_steps) == 1
        assert synth_steps[0].subagent_results == {
            "goal-a": "answer-a",
            "goal-b": "answer-b",
        }

    def test_synthesis_request_not_detected_as_completion(self):
        """The synthesis request must not be intercepted as a completion."""
        agg = Aggregator()
        template = _make_request(session_id="parent-sess")
        agg.register_fanout("corr-1", "parent-sess", 1,
                            request_template=template)
        agg.record_completion("corr-1", "goal", "answer")

        synth = agg.build_synthesis_request(
            "corr-1", "question", "user", "default",
        )

        # correlation_id must be empty so it's not intercepted
        assert synth.correlation_id == ""

        # Even if we check for completion step, shouldn't match
        is_completion = any(
            getattr(h, 'step_type', '') == 'subagent-completion'
            for h in synth.history
        )
        assert not is_completion
|
||||
162
tests/unit/test_agent/test_explainability_parsing.py
Normal file
162
tests/unit/test_agent/test_explainability_parsing.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""
|
||||
Unit tests for explainability API parsing — verifies that from_triples()
|
||||
correctly dispatches and parses the new orchestrator entity types.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.api.explainability import (
|
||||
ExplainEntity,
|
||||
Decomposition,
|
||||
Finding,
|
||||
Plan,
|
||||
StepResult,
|
||||
Synthesis,
|
||||
Analysis,
|
||||
Conclusion,
|
||||
TG_DECOMPOSITION,
|
||||
TG_FINDING,
|
||||
TG_PLAN_TYPE,
|
||||
TG_STEP_RESULT,
|
||||
TG_SYNTHESIS,
|
||||
TG_ANSWER_TYPE,
|
||||
TG_ANALYSIS,
|
||||
TG_CONCLUSION,
|
||||
TG_DOCUMENT,
|
||||
TG_SUBAGENT_GOAL,
|
||||
TG_PLAN_STEP,
|
||||
RDF_TYPE,
|
||||
)
|
||||
|
||||
# IRI of the Entity class in the W3C PROV namespace; used below as an
# extra rdf:type alongside the TrustGraph-specific types.
PROV_ENTITY = "http://www.w3.org/ns/prov#Entity"
|
||||
|
||||
|
||||
def _make_triples(uri, types, extras=None):
    """Build a list of (s, p, o) tuples for testing.

    Args:
        uri: Subject used for every triple.
        types: Iterable of type values; each yields ``(uri, RDF_TYPE, t)``.
        extras: Optional iterable of ``(predicate, object)`` pairs, each
            yielding ``(uri, predicate, object)`` after the type triples.

    Returns:
        The type triples followed by the extra triples, as a list.
    """
    triples = [(uri, RDF_TYPE, t) for t in types]
    if extras:
        triples.extend((uri, p, o) for p, o in extras)
    return triples
|
||||
|
||||
|
||||
class TestFromTriplesDispatch:
    """ExplainEntity.from_triples() picks the entity class from rdf:type."""

    def test_dispatches_decomposition(self):
        triples = _make_triples("urn:d", [PROV_ENTITY, TG_DECOMPOSITION])
        entity = ExplainEntity.from_triples("urn:d", triples)
        assert isinstance(entity, Decomposition)

    def test_dispatches_finding(self):
        triples = _make_triples("urn:f",
                                [PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE])
        entity = ExplainEntity.from_triples("urn:f", triples)
        assert isinstance(entity, Finding)

    def test_dispatches_plan(self):
        triples = _make_triples("urn:p", [PROV_ENTITY, TG_PLAN_TYPE])
        entity = ExplainEntity.from_triples("urn:p", triples)
        assert isinstance(entity, Plan)

    def test_dispatches_step_result(self):
        triples = _make_triples("urn:sr",
                                [PROV_ENTITY, TG_STEP_RESULT, TG_ANSWER_TYPE])
        entity = ExplainEntity.from_triples("urn:sr", triples)
        assert isinstance(entity, StepResult)

    def test_dispatches_synthesis(self):
        triples = _make_triples("urn:s",
                                [PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE])
        entity = ExplainEntity.from_triples("urn:s", triples)
        assert isinstance(entity, Synthesis)

    def test_dispatches_analysis_unchanged(self):
        """Pre-existing Analysis dispatch still works."""
        triples = _make_triples("urn:a", [PROV_ENTITY, TG_ANALYSIS])
        entity = ExplainEntity.from_triples("urn:a", triples)
        assert isinstance(entity, Analysis)

    def test_dispatches_conclusion_unchanged(self):
        """Pre-existing Conclusion dispatch still works."""
        triples = _make_triples("urn:c",
                                [PROV_ENTITY, TG_CONCLUSION, TG_ANSWER_TYPE])
        entity = ExplainEntity.from_triples("urn:c", triples)
        assert isinstance(entity, Conclusion)

    def test_finding_takes_precedence_over_synthesis(self):
        """Finding has Answer mixin but should dispatch to Finding, not
        Synthesis, because Finding is checked first."""
        triples = _make_triples("urn:f",
                                [PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE])
        entity = ExplainEntity.from_triples("urn:f", triples)
        assert isinstance(entity, Finding)
        assert not isinstance(entity, Synthesis)
|
||||
|
||||
|
||||
class TestDecompositionParsing:
    """Field extraction for Decomposition.from_triples()."""

    def test_parses_goals(self):
        """Each TG_SUBAGENT_GOAL triple contributes one goal."""
        triples = _make_triples("urn:d", [TG_DECOMPOSITION], [
            (TG_SUBAGENT_GOAL, "What is X?"),
            (TG_SUBAGENT_GOAL, "What is Y?"),
        ])
        entity = Decomposition.from_triples("urn:d", triples)
        assert set(entity.goals) == {"What is X?", "What is Y?"}

    def test_entity_type_field(self):
        triples = _make_triples("urn:d", [TG_DECOMPOSITION])
        entity = Decomposition.from_triples("urn:d", triples)
        assert entity.entity_type == "decomposition"

    def test_empty_goals(self):
        """With no goal triples, goals is an empty list."""
        triples = _make_triples("urn:d", [TG_DECOMPOSITION])
        entity = Decomposition.from_triples("urn:d", triples)
        assert entity.goals == []
|
||||
|
||||
|
||||
class TestFindingParsing:
    """Field extraction for Finding.from_triples()."""

    def test_parses_goal_and_document(self):
        """Goal and document are read from their respective predicates."""
        triples = _make_triples("urn:f", [TG_FINDING, TG_ANSWER_TYPE], [
            (TG_SUBAGENT_GOAL, "What is X?"),
            (TG_DOCUMENT, "urn:doc/finding"),
        ])
        entity = Finding.from_triples("urn:f", triples)
        assert entity.goal == "What is X?"
        assert entity.document == "urn:doc/finding"

    def test_entity_type_field(self):
        triples = _make_triples("urn:f", [TG_FINDING])
        entity = Finding.from_triples("urn:f", triples)
        assert entity.entity_type == "finding"
|
||||
|
||||
|
||||
class TestPlanParsing:
    """Field extraction for Plan.from_triples()."""

    def test_parses_steps(self):
        """Each TG_PLAN_STEP triple contributes one step."""
        triples = _make_triples("urn:p", [TG_PLAN_TYPE], [
            (TG_PLAN_STEP, "Define X"),
            (TG_PLAN_STEP, "Research Y"),
            (TG_PLAN_STEP, "Analyse Z"),
        ])
        entity = Plan.from_triples("urn:p", triples)
        assert set(entity.steps) == {"Define X", "Research Y", "Analyse Z"}

    def test_entity_type_field(self):
        triples = _make_triples("urn:p", [TG_PLAN_TYPE])
        entity = Plan.from_triples("urn:p", triples)
        assert entity.entity_type == "plan"
|
||||
|
||||
|
||||
class TestStepResultParsing:
    """Field extraction for StepResult.from_triples()."""

    def test_parses_step_and_document(self):
        """Step and document are read from their respective predicates."""
        triples = _make_triples("urn:sr", [TG_STEP_RESULT, TG_ANSWER_TYPE], [
            (TG_PLAN_STEP, "Define X"),
            (TG_DOCUMENT, "urn:doc/step"),
        ])
        entity = StepResult.from_triples("urn:sr", triples)
        assert entity.step == "Define X"
        assert entity.document == "urn:doc/step"

    def test_entity_type_field(self):
        triples = _make_triples("urn:sr", [TG_STEP_RESULT])
        entity = StepResult.from_triples("urn:sr", triples)
        assert entity.entity_type == "step-result"
|
||||
289
tests/unit/test_agent/test_meta_router.py
Normal file
289
tests/unit/test_agent/test_meta_router.py
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
"""
|
||||
Unit tests for the MetaRouter — task type identification and pattern selection.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.agent.orchestrator.meta_router import (
|
||||
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
|
||||
)
|
||||
|
||||
|
||||
def _make_config(patterns=None, task_types=None):
    """Build a config dict as the config service would provide.

    Args:
        patterns: Optional mapping of pattern ID to pattern data.  When
            non-empty it is stored under ``"agent-pattern"`` with each
            value JSON-encoded.
        task_types: Optional mapping of task-type ID to task-type data.
            When non-empty it is stored under ``"agent-task-type"`` with
            each value JSON-encoded.

    Returns:
        A dict holding only the sections whose argument was non-empty.
    """
    config = {}
    if patterns:
        config["agent-pattern"] = {
            pid: json.dumps(pdata) for pid, pdata in patterns.items()
        }
    if task_types:
        config["agent-task-type"] = {
            tid: json.dumps(tdata) for tid, tdata in task_types.items()
        }
    return config
|
||||
|
||||
|
||||
def _make_context(prompt_response):
    """Build a mock context that returns a mock prompt client.

    The returned callable ignores the requested service name and always
    hands back the same AsyncMock client, whose awaited ``prompt()``
    resolves to ``prompt_response``.
    """
    client = AsyncMock()
    client.prompt = AsyncMock(return_value=prompt_response)

    def context(service_name):
        # Same client for every service name.
        return client

    return context
|
||||
|
||||
|
||||
# Pattern definitions keyed by pattern ID, passed to _make_config() as
# the "patterns" argument.
SAMPLE_PATTERNS = {
    "react": {"name": "react", "description": "ReAct pattern"},
    "plan-then-execute": {"name": "plan-then-execute", "description": "Plan pattern"},
    "supervisor": {"name": "supervisor", "description": "Supervisor pattern"},
}
|
||||
|
||||
# Task-type definitions keyed by task-type ID, passed to _make_config()
# as the "task_types" argument.  Each entry lists the patterns valid for
# it and a framing string.
SAMPLE_TASK_TYPES = {
    "general": {
        "name": "general",
        "description": "General queries",
        "valid_patterns": ["react", "plan-then-execute", "supervisor"],
        "framing": "",
    },
    "research": {
        "name": "research",
        "description": "Research queries",
        "valid_patterns": ["react", "plan-then-execute"],
        "framing": "Focus on gathering information.",
    },
    "summarisation": {
        "name": "summarisation",
        "description": "Summarisation queries",
        "valid_patterns": ["react"],
        "framing": "Focus on concise synthesis.",
    },
}
|
||||
|
||||
|
||||
class TestMetaRouterInit:
    """Construction of MetaRouter with and without configuration."""

    def test_defaults_when_no_config(self):
        """A router built without config exposes "react" and "general"."""
        bare_router = MetaRouter()
        assert "react" in bare_router.patterns
        assert "general" in bare_router.task_types

    def test_loads_patterns_from_config(self):
        """Every configured pattern id shows up in router.patterns."""
        router = MetaRouter(config=_make_config(patterns=SAMPLE_PATTERNS))
        expected_ids = {"react", "plan-then-execute", "supervisor"}
        assert set(router.patterns.keys()) == expected_ids

    def test_loads_task_types_from_config(self):
        """Every configured task-type id shows up in router.task_types."""
        router = MetaRouter(config=_make_config(task_types=SAMPLE_TASK_TYPES))
        expected_ids = {"general", "research", "summarisation"}
        assert set(router.task_types.keys()) == expected_ids

    def test_handles_invalid_json_in_config(self):
        """An unparseable entry still yields a pattern named after its id."""
        broken_config = {"agent-pattern": {"react": "not valid json"}}
        router = MetaRouter(config=broken_config)
        assert "react" in router.patterns
        assert router.patterns["react"]["name"] == "react"
|
||||
|
||||
|
||||
class TestIdentifyTaskType:
    """MetaRouter.identify_task_type: LLM classification and its fallbacks."""

    @pytest.mark.asyncio
    async def test_skips_llm_when_single_task_type(self):
        """With a single task type the prompt client must stay untouched."""
        router = MetaRouter()  # Only "general"
        context = _make_context("should not be called")

        task_type, framing = await router.identify_task_type(
            "test question", context,
        )

        assert task_type == "general"
        # The mock context returns the same client whatever service name is
        # requested, so this inspects the client the router would have used.
        context("any-service").prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_uses_llm_when_multiple_task_types(self):
        """The LLM's answer selects the task type and its framing."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context("research")

        task_type, framing = await router.identify_task_type(
            "Research the topic", context,
        )

        assert task_type == "research"
        assert framing == "Focus on gathering information."

    @pytest.mark.asyncio
    async def test_handles_llm_returning_quoted_type(self):
        """A task type wrapped in double quotes is still recognised."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context('"summarisation"')

        task_type, _ = await router.identify_task_type(
            "Summarise this", context,
        )

        assert task_type == "summarisation"

    @pytest.mark.asyncio
    async def test_falls_back_on_unknown_type(self):
        """An LLM answer naming no configured type yields the default."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context("nonexistent-type")

        task_type, _ = await router.identify_task_type(
            "test question", context,
        )

        assert task_type == DEFAULT_TASK_TYPE
        # Prove the default came from rejecting the LLM's answer rather
        # than from never asking it.
        context("any-service").prompt.assert_called()

    @pytest.mark.asyncio
    async def test_falls_back_on_llm_error(self):
        """An exception raised by the prompt client yields the default."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)

        client = AsyncMock()
        client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))

        def context(name):
            return client

        task_type, _ = await router.identify_task_type(
            "test question", context,
        )

        assert task_type == DEFAULT_TASK_TYPE
        # Without this the test would also pass if the failing prompt were
        # never reached.
        client.prompt.assert_called()
|
||||
|
||||
|
||||
class TestSelectPattern:
    """MetaRouter.select_pattern: choosing among a task type's patterns."""

    @pytest.mark.asyncio
    async def test_skips_llm_when_single_valid_pattern(self):
        """A task type with one valid pattern needs no LLM call."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context("should not be called")

        # summarisation only has ["react"]
        pattern = await router.select_pattern(
            "Summarise this", "summarisation", context,
        )

        assert pattern == "react"
        # The mock context returns the same client whatever service name is
        # requested, so this inspects the client the router would have used.
        context("any-service").prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_uses_llm_when_multiple_valid_patterns(self):
        """The LLM's answer is used when it names a valid pattern."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context("plan-then-execute")

        # research has ["react", "plan-then-execute"]
        pattern = await router.select_pattern(
            "Research this", "research", context,
        )

        assert pattern == "plan-then-execute"

    @pytest.mark.asyncio
    async def test_respects_valid_patterns_constraint(self):
        """A pattern outside valid_patterns is rejected."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        # LLM returns supervisor, but research doesn't allow it
        context = _make_context("supervisor")

        pattern = await router.select_pattern(
            "Research this", "research", context,
        )

        # Should fall back to first valid pattern
        assert pattern == "react"

    @pytest.mark.asyncio
    async def test_falls_back_on_llm_error(self):
        """An exception raised by the prompt client yields a valid pattern."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)

        client = AsyncMock()
        client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))

        def context(name):
            return client

        # general has ["react", "plan-then-execute", "supervisor"]
        pattern = await router.select_pattern(
            "test", "general", context,
        )

        # Falls back to first valid pattern
        assert pattern == "react"
        # Without this the test would also pass if the failing prompt were
        # never reached.
        client.prompt.assert_called()

    @pytest.mark.asyncio
    async def test_falls_back_to_default_for_unknown_task_type(self):
        """An unconfigured task type still resolves to a pattern."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context("react")

        # Unknown task type — valid_patterns falls back to all patterns
        pattern = await router.select_pattern(
            "test", "unknown-type", context,
        )

        assert pattern == "react"
|
||||
|
||||
|
||||
class TestRoute:
    """End-to-end MetaRouter.route: task type first, then pattern."""

    @pytest.mark.asyncio
    async def test_full_routing_pipeline(self):
        """route() returns the pattern, task type and framing together."""
        router = MetaRouter(
            config=_make_config(
                patterns=SAMPLE_PATTERNS,
                task_types=SAMPLE_TASK_TYPES,
            ),
        )

        # The first prompt call is answered with the task type; every later
        # call is answered with the pattern.
        first_reply = iter(["research"])

        async def mock_prompt(**kwargs):
            return next(first_reply, "plan-then-execute")

        client = AsyncMock()
        client.prompt = mock_prompt

        def context(name):
            return client

        pattern, task_type, framing = await router.route(
            "Research the relationships", context,
        )

        assert task_type == "research"
        assert pattern == "plan-then-execute"
        assert framing == "Focus on gathering information."
|
||||
144
tests/unit/test_agent/test_pattern_base_subagent.py
Normal file
144
tests/unit/test_agent/test_pattern_base_subagent.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""
|
||||
Unit tests for PatternBase subagent helpers — is_subagent() and
|
||||
emit_subagent_completion().
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.schema import AgentRequest
|
||||
|
||||
from trustgraph.agent.orchestrator.pattern_base import PatternBase
|
||||
|
||||
|
||||
@dataclass
class MockProcessor:
    """Field-less stand-in passed to PatternBase as its processor."""
|
||||
|
||||
|
||||
def _make_request(**kwargs):
    """Build an AgentRequest from baseline test fields plus overrides.

    Keyword arguments replace or extend the baseline ``question``,
    ``user`` and ``collection`` values.
    """
    fields = {
        "question": "Test question",
        "user": "testuser",
        "collection": "default",
        **kwargs,
    }
    return AgentRequest(**fields)
|
||||
|
||||
|
||||
def _make_pattern():
    """Return a PatternBase constructed around a fresh MockProcessor."""
    processor = MockProcessor()
    return PatternBase(processor)
|
||||
|
||||
|
||||
class TestIsSubagent:
    """PatternBase.is_subagent as a function of the request's correlation_id."""

    def test_returns_true_when_correlation_id_set(self):
        """A non-empty correlation_id marks the request as a subagent."""
        subagent_request = _make_request(correlation_id="corr-123")
        verdict = _make_pattern().is_subagent(subagent_request)
        assert verdict is True

    def test_returns_false_when_correlation_id_empty(self):
        """An empty-string correlation_id is not a subagent."""
        blank_request = _make_request(correlation_id="")
        verdict = _make_pattern().is_subagent(blank_request)
        assert verdict is False

    def test_returns_false_when_correlation_id_missing(self):
        """A request built without correlation_id is not a subagent."""
        plain_request = _make_request()
        verdict = _make_pattern().is_subagent(plain_request)
        assert verdict is False
|
||||
|
||||
|
||||
class TestEmitSubagentCompletion:
    """Shape of the request PatternBase.emit_subagent_completion forwards."""

    @staticmethod
    async def _emit(answer, **request_fields):
        """Emit a completion for a fresh request; return the mocked next fn."""
        pattern = _make_pattern()
        request = _make_request(**request_fields)
        next_fn = AsyncMock()
        await pattern.emit_subagent_completion(request, next_fn, answer)
        return next_fn

    @pytest.mark.asyncio
    async def test_calls_next_with_completion_request(self):
        """The next callback is invoked exactly once with an AgentRequest."""
        next_fn = await self._emit(
            "The answer is Y",
            correlation_id="corr-123",
            parent_session_id="parent-sess",
            subagent_goal="What is X?",
            expected_siblings=4,
        )

        next_fn.assert_called_once()
        forwarded = next_fn.call_args[0][0]
        assert isinstance(forwarded, AgentRequest)

    @pytest.mark.asyncio
    async def test_completion_has_correct_step_type(self):
        """History holds a single step typed "subagent-completion"."""
        next_fn = await self._emit(
            "answer text",
            correlation_id="corr-123",
            subagent_goal="What is X?",
        )

        forwarded = next_fn.call_args[0][0]
        assert len(forwarded.history) == 1
        only_step = forwarded.history[0]
        assert only_step.step_type == "subagent-completion"

    @pytest.mark.asyncio
    async def test_completion_carries_answer_in_observation(self):
        """The answer text travels in the step's observation field."""
        next_fn = await self._emit(
            "The answer is Y",
            correlation_id="corr-123",
            subagent_goal="What is X?",
        )

        forwarded = next_fn.call_args[0][0]
        only_step = forwarded.history[0]
        assert only_step.observation == "The answer is Y"

    @pytest.mark.asyncio
    async def test_completion_preserves_correlation_fields(self):
        """Correlation fields are copied unchanged onto the completion."""
        next_fn = await self._emit(
            "answer",
            correlation_id="corr-123",
            parent_session_id="parent-sess",
            subagent_goal="What is X?",
            expected_siblings=4,
        )

        forwarded = next_fn.call_args[0][0]
        assert forwarded.correlation_id == "corr-123"
        assert forwarded.parent_session_id == "parent-sess"
        assert forwarded.subagent_goal == "What is X?"
        assert forwarded.expected_siblings == 4

    @pytest.mark.asyncio
    async def test_completion_has_empty_pattern(self):
        """The completion request carries an empty pattern string."""
        next_fn = await self._emit(
            "answer",
            correlation_id="corr-123",
            subagent_goal="goal",
        )

        forwarded = next_fn.call_args[0][0]
        assert forwarded.pattern == ""
|
||||
226
tests/unit/test_agent/test_provenance_triples.py
Normal file
226
tests/unit/test_agent/test_provenance_triples.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""
|
||||
Unit tests for orchestrator provenance triple builders.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.provenance import (
|
||||
agent_decomposition_triples,
|
||||
agent_finding_triples,
|
||||
agent_plan_triples,
|
||||
agent_step_result_triples,
|
||||
agent_synthesis_triples,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
|
||||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
TG_SYNTHESIS, TG_ANSWER_TYPE, TG_DOCUMENT,
|
||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
)
|
||||
|
||||
|
||||
def _triple_set(triples):
|
||||
"""Convert triples to a set of (s_iri, p_iri, o_value) for easy assertion."""
|
||||
result = set()
|
||||
for t in triples:
|
||||
s = t.s.iri
|
||||
p = t.p.iri
|
||||
o = t.o.iri if t.o.iri else t.o.value
|
||||
result.add((s, p, o))
|
||||
return result
|
||||
|
||||
|
||||
def _has_type(triples, uri, rdf_type):
    """Return True when the triples assert ``uri rdf:type rdf_type``."""
    type_assertion = (uri, RDF_TYPE, rdf_type)
    return type_assertion in _triple_set(triples)
|
||||
|
||||
|
||||
def _get_values(triples, uri, predicate):
    """List the distinct objects recorded for ``uri`` under ``predicate``.

    Values are read from a de-duplicated set, so their order is unspecified.
    """
    return [
        obj
        for subj, pred, obj in _triple_set(triples)
        if subj == uri and pred == predicate
    ]
|
||||
|
||||
|
||||
class TestDecompositionTriples:
    """Triples produced by agent_decomposition_triples."""

    URI = "urn:decompose"
    SESSION = "urn:session"

    def test_has_correct_types(self):
        """The node is typed as a PROV entity and a decomposition."""
        built = agent_decomposition_triples(
            self.URI, self.SESSION, ["goal-a", "goal-b"],
        )
        for expected_type in (PROV_ENTITY, TG_DECOMPOSITION):
            assert _has_type(built, self.URI, expected_type)

    def test_not_answer_type(self):
        """A decomposition is not typed as an answer."""
        built = agent_decomposition_triples(
            self.URI, self.SESSION, ["goal-a"],
        )
        assert not _has_type(built, self.URI, TG_ANSWER_TYPE)

    def test_links_to_session(self):
        """The node is linked to the session via wasGeneratedBy."""
        built = agent_decomposition_triples(
            self.URI, self.SESSION, ["goal-a"],
        )
        link = (self.URI, PROV_WAS_GENERATED_BY, self.SESSION)
        assert link in _triple_set(built)

    def test_includes_goals(self):
        """Every goal is recorded under the subagent-goal predicate."""
        goals = ["What is X?", "What is Y?", "What is Z?"]
        built = agent_decomposition_triples(self.URI, self.SESSION, goals)
        recorded = _get_values(built, self.URI, TG_SUBAGENT_GOAL)
        assert set(recorded) == set(goals)

    def test_label_includes_count(self):
        """Some label on the node mentions the number of goals."""
        built = agent_decomposition_triples(
            self.URI, self.SESSION, ["a", "b", "c"],
        )
        labels = _get_values(built, self.URI, RDFS_LABEL)
        assert any("3" in label for label in labels)
|
||||
|
||||
|
||||
class TestFindingTriples:
    """Triples produced by agent_finding_triples."""

    URI = "urn:finding"
    PARENT = "urn:decompose"

    def test_has_correct_types(self):
        """The node is typed as a PROV entity, a finding and an answer."""
        built = agent_finding_triples(self.URI, self.PARENT, "What is X?")
        for expected_type in (PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE):
            assert _has_type(built, self.URI, expected_type)

    def test_links_to_decomposition(self):
        """The node is linked to the decomposition via wasDerivedFrom."""
        built = agent_finding_triples(self.URI, self.PARENT, "What is X?")
        link = (self.URI, PROV_WAS_DERIVED_FROM, self.PARENT)
        assert link in _triple_set(built)

    def test_includes_goal(self):
        """The goal text is recorded under the subagent-goal predicate."""
        built = agent_finding_triples(self.URI, self.PARENT, "What is X?")
        recorded = _get_values(built, self.URI, TG_SUBAGENT_GOAL)
        assert "What is X?" in recorded

    def test_includes_document_when_provided(self):
        """A supplied document_id is recorded under the document predicate."""
        built = agent_finding_triples(
            self.URI, self.PARENT, "goal",
            document_id="urn:doc/1",
        )
        recorded = _get_values(built, self.URI, TG_DOCUMENT)
        assert "urn:doc/1" in recorded

    def test_no_document_when_none(self):
        """Without document_id no document triple is emitted."""
        built = agent_finding_triples(self.URI, self.PARENT, "goal")
        recorded = _get_values(built, self.URI, TG_DOCUMENT)
        assert recorded == []
|
||||
|
||||
|
||||
class TestPlanTriples:
    """Triples produced by agent_plan_triples."""

    URI = "urn:plan"
    SESSION = "urn:session"

    def test_has_correct_types(self):
        """The node is typed as a PROV entity and a plan."""
        built = agent_plan_triples(self.URI, self.SESSION, ["step-a"])
        for expected_type in (PROV_ENTITY, TG_PLAN_TYPE):
            assert _has_type(built, self.URI, expected_type)

    def test_not_answer_type(self):
        """A plan is not typed as an answer."""
        built = agent_plan_triples(self.URI, self.SESSION, ["step-a"])
        assert not _has_type(built, self.URI, TG_ANSWER_TYPE)

    def test_links_to_session(self):
        """The node is linked to the session via wasGeneratedBy."""
        built = agent_plan_triples(self.URI, self.SESSION, ["step-a"])
        link = (self.URI, PROV_WAS_GENERATED_BY, self.SESSION)
        assert link in _triple_set(built)

    def test_includes_steps(self):
        """Every step is recorded under the plan-step predicate."""
        steps = ["Define X", "Research Y", "Analyse Z"]
        built = agent_plan_triples(self.URI, self.SESSION, steps)
        recorded = _get_values(built, self.URI, TG_PLAN_STEP)
        assert set(recorded) == set(steps)

    def test_label_includes_count(self):
        """Some label on the node mentions the number of steps."""
        built = agent_plan_triples(self.URI, self.SESSION, ["a", "b"])
        labels = _get_values(built, self.URI, RDFS_LABEL)
        assert any("2" in label for label in labels)
|
||||
|
||||
|
||||
class TestStepResultTriples:
    """Triples produced by agent_step_result_triples."""

    URI = "urn:step"
    PLAN = "urn:plan"

    def test_has_correct_types(self):
        """The node is typed as a PROV entity, a step result and an answer."""
        built = agent_step_result_triples(self.URI, self.PLAN, "Define X")
        for expected_type in (PROV_ENTITY, TG_STEP_RESULT, TG_ANSWER_TYPE):
            assert _has_type(built, self.URI, expected_type)

    def test_links_to_plan(self):
        """The node is linked to the plan via wasDerivedFrom."""
        built = agent_step_result_triples(self.URI, self.PLAN, "Define X")
        link = (self.URI, PROV_WAS_DERIVED_FROM, self.PLAN)
        assert link in _triple_set(built)

    def test_includes_goal(self):
        """The step text is recorded under the plan-step predicate."""
        built = agent_step_result_triples(self.URI, self.PLAN, "Define X")
        recorded = _get_values(built, self.URI, TG_PLAN_STEP)
        assert "Define X" in recorded

    def test_includes_document_when_provided(self):
        """A supplied document_id is recorded under the document predicate."""
        built = agent_step_result_triples(
            self.URI, self.PLAN, "goal",
            document_id="urn:doc/step",
        )
        recorded = _get_values(built, self.URI, TG_DOCUMENT)
        assert "urn:doc/step" in recorded
|
||||
|
||||
|
||||
class TestSynthesisTriples:
    """Triples produced by agent_synthesis_triples."""

    URI = "urn:synthesis"

    def test_has_correct_types(self):
        """The node is typed as a PROV entity, a synthesis and an answer."""
        built = agent_synthesis_triples(self.URI, "urn:previous")
        for expected_type in (PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE):
            assert _has_type(built, self.URI, expected_type)

    def test_links_to_previous(self):
        """The node is linked to the preceding entity via wasDerivedFrom."""
        built = agent_synthesis_triples(self.URI, "urn:last-finding")
        link = (self.URI, PROV_WAS_DERIVED_FROM, "urn:last-finding")
        assert link in _triple_set(built)

    def test_includes_document_when_provided(self):
        """A supplied document_id is recorded under the document predicate."""
        built = agent_synthesis_triples(
            self.URI, "urn:previous",
            document_id="urn:doc/synthesis",
        )
        recorded = _get_values(built, self.URI, TG_DOCUMENT)
        assert "urn:doc/synthesis" in recorded

    def test_label_is_synthesis(self):
        """The node carries the label "Synthesis"."""
        built = agent_synthesis_triples(self.URI, "urn:previous")
        labels = _get_values(built, self.URI, RDFS_LABEL)
        assert "Synthesis" in labels
|
||||
Loading…
Add table
Add a link
Reference in a new issue