diff --git a/tests/contract/test_orchestrator_contracts.py b/tests/contract/test_orchestrator_contracts.py new file mode 100644 index 00000000..ab168ece --- /dev/null +++ b/tests/contract/test_orchestrator_contracts.py @@ -0,0 +1,177 @@ +""" +Contract tests for orchestrator message schemas. + +Verifies that AgentRequest/AgentStep with orchestration fields +serialise and deserialise correctly through the Pulsar schema layer. +""" + +import pytest +import json + +from trustgraph.schema import AgentRequest, AgentStep, PlanStep + + +@pytest.mark.contract +class TestOrchestrationFieldContracts: + """Contract tests for orchestration fields on AgentRequest.""" + + def test_agent_request_orchestration_fields_roundtrip(self): + req = AgentRequest( + question="Test question", + user="testuser", + collection="default", + correlation_id="corr-123", + parent_session_id="parent-sess", + subagent_goal="What is X?", + expected_siblings=4, + pattern="react", + task_type="research", + framing="Focus on accuracy", + conversation_id="conv-456", + ) + + assert req.correlation_id == "corr-123" + assert req.parent_session_id == "parent-sess" + assert req.subagent_goal == "What is X?" 
+ assert req.expected_siblings == 4 + assert req.pattern == "react" + assert req.task_type == "research" + assert req.framing == "Focus on accuracy" + assert req.conversation_id == "conv-456" + + def test_agent_request_orchestration_fields_default_empty(self): + req = AgentRequest( + question="Test question", + user="testuser", + ) + + assert req.correlation_id == "" + assert req.parent_session_id == "" + assert req.subagent_goal == "" + assert req.expected_siblings == 0 + assert req.pattern == "" + assert req.task_type == "" + assert req.framing == "" + + +@pytest.mark.contract +class TestSubagentCompletionStepContract: + """Contract tests for subagent-completion step type.""" + + def test_subagent_completion_step_fields(self): + step = AgentStep( + thought="Subagent completed", + action="complete", + arguments={}, + observation="The answer text", + step_type="subagent-completion", + ) + + assert step.step_type == "subagent-completion" + assert step.observation == "The answer text" + assert step.thought == "Subagent completed" + assert step.action == "complete" + + def test_subagent_completion_in_request_history(self): + step = AgentStep( + thought="Subagent completed", + action="complete", + arguments={}, + observation="answer", + step_type="subagent-completion", + ) + req = AgentRequest( + question="goal", + user="testuser", + correlation_id="corr-123", + history=[step], + ) + + assert len(req.history) == 1 + assert req.history[0].step_type == "subagent-completion" + assert req.history[0].observation == "answer" + + +@pytest.mark.contract +class TestSynthesisStepContract: + """Contract tests for synthesis step type with subagent_results.""" + + def test_synthesis_step_with_results(self): + results = {"goal-a": "answer-a", "goal-b": "answer-b"} + step = AgentStep( + thought="All subagents completed", + action="aggregate", + arguments={}, + observation=json.dumps(results), + step_type="synthesise", + subagent_results=results, + ) + + assert step.step_type == 
"synthesise" + assert step.subagent_results == results + assert json.loads(step.observation) == results + + def test_synthesis_request_matches_supervisor_expectations(self): + """The synthesis request built by the aggregator must be + recognisable by SupervisorPattern._synthesise().""" + results = {"goal-a": "answer-a", "goal-b": "answer-b"} + step = AgentStep( + thought="All subagents completed", + action="aggregate", + arguments={}, + observation=json.dumps(results), + step_type="synthesise", + subagent_results=results, + ) + + req = AgentRequest( + question="Original question", + user="testuser", + pattern="supervisor", + correlation_id="", + session_id="parent-sess", + history=[step], + ) + + # SupervisorPattern checks for step_type='synthesise' with + # subagent_results + has_results = bool( + req.history + and any( + getattr(h, 'step_type', '') == 'synthesise' + and getattr(h, 'subagent_results', None) + for h in req.history + ) + ) + assert has_results + + # Pattern must be supervisor + assert req.pattern == "supervisor" + + # Correlation ID must be empty (not re-intercepted) + assert req.correlation_id == "" + + +@pytest.mark.contract +class TestPlanStepContract: + """Contract tests for plan steps in history.""" + + def test_plan_step_in_history(self): + plan = [ + PlanStep(goal="Step 1", tool_hint="knowledge-query", + depends_on=[], status="completed", result="done"), + PlanStep(goal="Step 2", tool_hint="", + depends_on=[0], status="pending", result=""), + ] + step = AgentStep( + thought="Created plan", + action="plan", + step_type="plan", + plan=plan, + ) + + assert step.step_type == "plan" + assert len(step.plan) == 2 + assert step.plan[0].goal == "Step 1" + assert step.plan[0].status == "completed" + assert step.plan[1].depends_on == [0] diff --git a/tests/contract/test_provenance_wire_format.py b/tests/contract/test_provenance_wire_format.py new file mode 100644 index 00000000..de195f68 --- /dev/null +++ b/tests/contract/test_provenance_wire_format.py 
@@ -0,0 +1,129 @@ +""" +Contract tests for provenance triple wire format — verifies that triples +built by the provenance library can be parsed by the explainability API +through the wire format conversion. +""" + +import pytest + +from trustgraph.schema import IRI, LITERAL + +from trustgraph.provenance import ( + agent_decomposition_triples, + agent_finding_triples, + agent_plan_triples, + agent_step_result_triples, + agent_synthesis_triples, +) + +from trustgraph.api.explainability import ( + ExplainEntity, + Decomposition, + Finding, + Plan, + StepResult, + Synthesis, + wire_triples_to_tuples, +) + + +def _triples_to_wire(triples): + """Convert provenance Triple objects to the wire format dicts + that the gateway/socket client would produce.""" + wire = [] + for t in triples: + entry = { + "s": _term_to_wire(t.s), + "p": _term_to_wire(t.p), + "o": _term_to_wire(t.o), + } + wire.append(entry) + return wire + + +def _term_to_wire(term): + """Convert a Term to wire format dict.""" + if term.type == IRI: + return {"t": "i", "i": term.iri} + elif term.type == LITERAL: + return {"t": "l", "v": term.value} + return {"t": "l", "v": str(term)} + + +def _roundtrip(triples, uri): + """Convert triples through wire format and parse via from_triples.""" + wire = _triples_to_wire(triples) + tuples = wire_triples_to_tuples(wire) + return ExplainEntity.from_triples(uri, tuples) + + +@pytest.mark.contract +class TestDecompositionWireFormat: + + def test_roundtrip(self): + triples = agent_decomposition_triples( + "urn:decompose", "urn:session", + ["What is X?", "What is Y?"], + ) + entity = _roundtrip(triples, "urn:decompose") + + assert isinstance(entity, Decomposition) + assert set(entity.goals) == {"What is X?", "What is Y?"} + + +@pytest.mark.contract +class TestFindingWireFormat: + + def test_roundtrip(self): + triples = agent_finding_triples( + "urn:finding", "urn:decompose", "What is X?", + document_id="urn:doc/finding", + ) + entity = _roundtrip(triples, "urn:finding") + + 
assert isinstance(entity, Finding) + assert entity.goal == "What is X?" + assert entity.document == "urn:doc/finding" + + +@pytest.mark.contract +class TestPlanWireFormat: + + def test_roundtrip(self): + triples = agent_plan_triples( + "urn:plan", "urn:session", + ["Step 1", "Step 2", "Step 3"], + ) + entity = _roundtrip(triples, "urn:plan") + + assert isinstance(entity, Plan) + assert set(entity.steps) == {"Step 1", "Step 2", "Step 3"} + + +@pytest.mark.contract +class TestStepResultWireFormat: + + def test_roundtrip(self): + triples = agent_step_result_triples( + "urn:step", "urn:plan", "Define X", + document_id="urn:doc/step", + ) + entity = _roundtrip(triples, "urn:step") + + assert isinstance(entity, StepResult) + assert entity.step == "Define X" + assert entity.document == "urn:doc/step" + + +@pytest.mark.contract +class TestSynthesisWireFormat: + + def test_roundtrip(self): + triples = agent_synthesis_triples( + "urn:synthesis", "urn:previous", + document_id="urn:doc/synthesis", + ) + entity = _roundtrip(triples, "urn:synthesis") + + assert isinstance(entity, Synthesis) + assert entity.document == "urn:doc/synthesis" diff --git a/tests/unit/test_agent/test_aggregator.py b/tests/unit/test_agent/test_aggregator.py new file mode 100644 index 00000000..afb19499 --- /dev/null +++ b/tests/unit/test_agent/test_aggregator.py @@ -0,0 +1,216 @@ +""" +Unit tests for the Aggregator — tracks fan-out correlations and triggers +synthesis when all subagents complete. 
+""" + +import time +import pytest + +from trustgraph.schema import AgentRequest, AgentStep + +from trustgraph.agent.orchestrator.aggregator import Aggregator + + +def _make_request(question="Test question", user="testuser", + collection="default", streaming=False, + session_id="parent-session", task_type="research", + framing="test framing", conversation_id="conv-1"): + return AgentRequest( + question=question, + user=user, + collection=collection, + streaming=streaming, + session_id=session_id, + task_type=task_type, + framing=framing, + conversation_id=conversation_id, + ) + + +class TestRegisterFanout: + + def test_stores_correlation_entry(self): + agg = Aggregator() + agg.register_fanout("corr-1", "parent-1", 3) + + assert "corr-1" in agg.correlations + entry = agg.correlations["corr-1"] + assert entry["parent_session_id"] == "parent-1" + assert entry["expected"] == 3 + assert entry["results"] == {} + + def test_stores_request_template(self): + agg = Aggregator() + template = _make_request() + agg.register_fanout("corr-1", "parent-1", 2, + request_template=template) + + entry = agg.correlations["corr-1"] + assert entry["request_template"] is template + + def test_records_creation_time(self): + agg = Aggregator() + before = time.time() + agg.register_fanout("corr-1", "parent-1", 2) + after = time.time() + + created = agg.correlations["corr-1"]["created_at"] + assert before <= created <= after + + +class TestRecordCompletion: + + def test_returns_false_until_all_done(self): + agg = Aggregator() + agg.register_fanout("corr-1", "parent-1", 3) + + assert agg.record_completion("corr-1", "goal-a", "answer-a") is False + assert agg.record_completion("corr-1", "goal-b", "answer-b") is False + assert agg.record_completion("corr-1", "goal-c", "answer-c") is True + + def test_returns_none_for_unknown_correlation(self): + agg = Aggregator() + result = agg.record_completion("unknown", "goal", "answer") + assert result is None + + def test_stores_results_by_goal(self): + agg 
= Aggregator() + agg.register_fanout("corr-1", "parent-1", 2) + + agg.record_completion("corr-1", "goal-a", "answer-a") + agg.record_completion("corr-1", "goal-b", "answer-b") + + results = agg.correlations["corr-1"]["results"] + assert results["goal-a"] == "answer-a" + assert results["goal-b"] == "answer-b" + + def test_single_subagent(self): + agg = Aggregator() + agg.register_fanout("corr-1", "parent-1", 1) + + assert agg.record_completion("corr-1", "goal-a", "answer") is True + + +class TestGetOriginalRequest: + + def test_peeks_without_consuming(self): + agg = Aggregator() + template = _make_request() + agg.register_fanout("corr-1", "parent-1", 2, + request_template=template) + + result = agg.get_original_request("corr-1") + assert result is template + # Entry still exists + assert "corr-1" in agg.correlations + + def test_returns_none_for_unknown(self): + agg = Aggregator() + assert agg.get_original_request("unknown") is None + + +class TestBuildSynthesisRequest: + + def test_builds_correct_request(self): + agg = Aggregator() + template = _make_request( + question="Original question", + streaming=True, + task_type="risk-assessment", + framing="Assess risks", + ) + agg.register_fanout("corr-1", "parent-1", 2, + request_template=template) + agg.record_completion("corr-1", "goal-a", "answer-a") + agg.record_completion("corr-1", "goal-b", "answer-b") + + req = agg.build_synthesis_request( + "corr-1", + original_question="Original question", + user="testuser", + collection="default", + ) + + assert req.question == "Original question" + assert req.pattern == "supervisor" + assert req.session_id == "parent-1" + assert req.correlation_id == "" # Must be empty + assert req.streaming == True + assert req.task_type == "risk-assessment" + assert req.framing == "Assess risks" + + def test_synthesis_step_in_history(self): + agg = Aggregator() + template = _make_request() + agg.register_fanout("corr-1", "parent-1", 2, + request_template=template) + 
agg.record_completion("corr-1", "goal-a", "answer-a") + agg.record_completion("corr-1", "goal-b", "answer-b") + + req = agg.build_synthesis_request( + "corr-1", "question", "user", "default", + ) + + # Last history step should be the synthesis step + assert len(req.history) >= 1 + synth_step = req.history[-1] + assert synth_step.step_type == "synthesise" + assert synth_step.subagent_results == { + "goal-a": "answer-a", + "goal-b": "answer-b", + } + + def test_consumes_correlation_entry(self): + agg = Aggregator() + template = _make_request() + agg.register_fanout("corr-1", "parent-1", 1, + request_template=template) + agg.record_completion("corr-1", "goal-a", "answer-a") + + agg.build_synthesis_request( + "corr-1", "question", "user", "default", + ) + + # Entry should be removed + assert "corr-1" not in agg.correlations + + def test_raises_for_unknown_correlation(self): + agg = Aggregator() + with pytest.raises(RuntimeError, match="No results"): + agg.build_synthesis_request( + "unknown", "question", "user", "default", + ) + + +class TestCleanupStale: + + def test_removes_entries_older_than_timeout(self): + agg = Aggregator(timeout=1) + agg.register_fanout("corr-1", "parent-1", 2) + + # Backdate the creation time + agg.correlations["corr-1"]["created_at"] = time.time() - 2 + + stale = agg.cleanup_stale() + assert "corr-1" in stale + assert "corr-1" not in agg.correlations + + def test_keeps_recent_entries(self): + agg = Aggregator(timeout=300) + agg.register_fanout("corr-1", "parent-1", 2) + + stale = agg.cleanup_stale() + assert stale == [] + assert "corr-1" in agg.correlations + + def test_mixed_stale_and_fresh(self): + agg = Aggregator(timeout=1) + agg.register_fanout("stale", "parent-1", 2) + agg.register_fanout("fresh", "parent-2", 2) + + agg.correlations["stale"]["created_at"] = time.time() - 2 + + stale = agg.cleanup_stale() + assert "stale" in stale + assert "stale" not in agg.correlations + assert "fresh" in agg.correlations diff --git 
a/tests/unit/test_agent/test_completion_dispatch.py b/tests/unit/test_agent/test_completion_dispatch.py new file mode 100644 index 00000000..8c01f126 --- /dev/null +++ b/tests/unit/test_agent/test_completion_dispatch.py @@ -0,0 +1,174 @@ +""" +Unit tests for completion dispatch — verifies that agent_request() in the +orchestrator service correctly intercepts subagent completion messages and +routes them to _handle_subagent_completion. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.schema import AgentRequest, AgentStep + +from trustgraph.agent.orchestrator.aggregator import Aggregator + + +def _make_request(**kwargs): + defaults = dict( + question="Test question", + user="testuser", + collection="default", + ) + defaults.update(kwargs) + return AgentRequest(**defaults) + + +def _make_completion_request(correlation_id, goal, answer): + """Build a completion request as emit_subagent_completion would.""" + step = AgentStep( + thought="Subagent completed", + action="complete", + arguments={}, + observation=answer, + step_type="subagent-completion", + ) + return _make_request( + correlation_id=correlation_id, + parent_session_id="parent-sess", + subagent_goal=goal, + expected_siblings=2, + history=[step], + ) + + +class TestCompletionDetection: + """Test that completion messages are correctly identified.""" + + def test_is_completion_when_correlation_id_and_step_type(self): + req = _make_completion_request("corr-1", "goal-a", "answer-a") + + has_correlation = bool(getattr(req, 'correlation_id', '')) + is_completion = any( + getattr(h, 'step_type', '') == 'subagent-completion' + for h in req.history + ) + + assert has_correlation + assert is_completion + + def test_not_completion_without_correlation_id(self): + step = AgentStep( + step_type="subagent-completion", + observation="answer", + ) + req = _make_request( + correlation_id="", + history=[step], + ) + + has_correlation = bool(getattr(req, 'correlation_id', '')) + 
assert not has_correlation + + def test_not_completion_without_step_type(self): + step = AgentStep( + step_type="react", + observation="answer", + ) + req = _make_request( + correlation_id="corr-1", + history=[step], + ) + + is_completion = any( + getattr(h, 'step_type', '') == 'subagent-completion' + for h in req.history + ) + assert not is_completion + + def test_not_completion_with_empty_history(self): + req = _make_request( + correlation_id="corr-1", + history=[], + ) + assert not req.history + + +class TestAggregatorIntegration: + """Test the aggregator flow as used by _handle_subagent_completion.""" + + def test_full_completion_flow(self): + """Simulates the flow: register, record completions, build synthesis.""" + agg = Aggregator() + template = _make_request( + question="Original question", + streaming=True, + task_type="risk-assessment", + framing="Assess risks", + session_id="parent-sess", + ) + + # Register fan-out + agg.register_fanout("corr-1", "parent-sess", 2, + request_template=template) + + # First completion — not all done + all_done = agg.record_completion( + "corr-1", "goal-a", "answer-a", + ) + assert all_done is False + + # Second completion — all done + all_done = agg.record_completion( + "corr-1", "goal-b", "answer-b", + ) + assert all_done is True + + # Peek at template + peeked = agg.get_original_request("corr-1") + assert peeked.question == "Original question" + + # Build synthesis request + synth = agg.build_synthesis_request( + "corr-1", + original_question="Original question", + user="testuser", + collection="default", + ) + + # Verify synthesis request + assert synth.pattern == "supervisor" + assert synth.correlation_id == "" + assert synth.session_id == "parent-sess" + assert synth.streaming is True + + # Verify synthesis history has results + synth_steps = [ + s for s in synth.history + if getattr(s, 'step_type', '') == 'synthesise' + ] + assert len(synth_steps) == 1 + assert synth_steps[0].subagent_results == { + "goal-a": 
"answer-a", + "goal-b": "answer-b", + } + + def test_synthesis_request_not_detected_as_completion(self): + """The synthesis request must not be intercepted as a completion.""" + agg = Aggregator() + template = _make_request(session_id="parent-sess") + agg.register_fanout("corr-1", "parent-sess", 1, + request_template=template) + agg.record_completion("corr-1", "goal", "answer") + + synth = agg.build_synthesis_request( + "corr-1", "question", "user", "default", + ) + + # correlation_id must be empty so it's not intercepted + assert synth.correlation_id == "" + + # Even if we check for completion step, shouldn't match + is_completion = any( + getattr(h, 'step_type', '') == 'subagent-completion' + for h in synth.history + ) + assert not is_completion diff --git a/tests/unit/test_agent/test_explainability_parsing.py b/tests/unit/test_agent/test_explainability_parsing.py new file mode 100644 index 00000000..e09a7f1f --- /dev/null +++ b/tests/unit/test_agent/test_explainability_parsing.py @@ -0,0 +1,162 @@ +""" +Unit tests for explainability API parsing — verifies that from_triples() +correctly dispatches and parses the new orchestrator entity types. 
+""" + +import pytest + +from trustgraph.api.explainability import ( + ExplainEntity, + Decomposition, + Finding, + Plan, + StepResult, + Synthesis, + Analysis, + Conclusion, + TG_DECOMPOSITION, + TG_FINDING, + TG_PLAN_TYPE, + TG_STEP_RESULT, + TG_SYNTHESIS, + TG_ANSWER_TYPE, + TG_ANALYSIS, + TG_CONCLUSION, + TG_DOCUMENT, + TG_SUBAGENT_GOAL, + TG_PLAN_STEP, + RDF_TYPE, +) + +PROV_ENTITY = "http://www.w3.org/ns/prov#Entity" + + +def _make_triples(uri, types, extras=None): + """Build a list of (s, p, o) tuples for testing.""" + triples = [(uri, RDF_TYPE, t) for t in types] + if extras: + triples.extend((uri, p, o) for p, o in extras) + return triples + + +class TestFromTriplesDispatch: + + def test_dispatches_decomposition(self): + triples = _make_triples("urn:d", [PROV_ENTITY, TG_DECOMPOSITION]) + entity = ExplainEntity.from_triples("urn:d", triples) + assert isinstance(entity, Decomposition) + + def test_dispatches_finding(self): + triples = _make_triples("urn:f", + [PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE]) + entity = ExplainEntity.from_triples("urn:f", triples) + assert isinstance(entity, Finding) + + def test_dispatches_plan(self): + triples = _make_triples("urn:p", [PROV_ENTITY, TG_PLAN_TYPE]) + entity = ExplainEntity.from_triples("urn:p", triples) + assert isinstance(entity, Plan) + + def test_dispatches_step_result(self): + triples = _make_triples("urn:sr", + [PROV_ENTITY, TG_STEP_RESULT, TG_ANSWER_TYPE]) + entity = ExplainEntity.from_triples("urn:sr", triples) + assert isinstance(entity, StepResult) + + def test_dispatches_synthesis(self): + triples = _make_triples("urn:s", + [PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE]) + entity = ExplainEntity.from_triples("urn:s", triples) + assert isinstance(entity, Synthesis) + + def test_dispatches_analysis_unchanged(self): + triples = _make_triples("urn:a", [PROV_ENTITY, TG_ANALYSIS]) + entity = ExplainEntity.from_triples("urn:a", triples) + assert isinstance(entity, Analysis) + + def 
test_dispatches_conclusion_unchanged(self): + triples = _make_triples("urn:c", + [PROV_ENTITY, TG_CONCLUSION, TG_ANSWER_TYPE]) + entity = ExplainEntity.from_triples("urn:c", triples) + assert isinstance(entity, Conclusion) + + def test_finding_takes_precedence_over_synthesis(self): + """Finding has Answer mixin but should dispatch to Finding, not + Synthesis, because Finding is checked first.""" + triples = _make_triples("urn:f", + [PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE]) + entity = ExplainEntity.from_triples("urn:f", triples) + assert isinstance(entity, Finding) + assert not isinstance(entity, Synthesis) + + +class TestDecompositionParsing: + + def test_parses_goals(self): + triples = _make_triples("urn:d", [TG_DECOMPOSITION], [ + (TG_SUBAGENT_GOAL, "What is X?"), + (TG_SUBAGENT_GOAL, "What is Y?"), + ]) + entity = Decomposition.from_triples("urn:d", triples) + assert set(entity.goals) == {"What is X?", "What is Y?"} + + def test_entity_type_field(self): + triples = _make_triples("urn:d", [TG_DECOMPOSITION]) + entity = Decomposition.from_triples("urn:d", triples) + assert entity.entity_type == "decomposition" + + def test_empty_goals(self): + triples = _make_triples("urn:d", [TG_DECOMPOSITION]) + entity = Decomposition.from_triples("urn:d", triples) + assert entity.goals == [] + + +class TestFindingParsing: + + def test_parses_goal_and_document(self): + triples = _make_triples("urn:f", [TG_FINDING, TG_ANSWER_TYPE], [ + (TG_SUBAGENT_GOAL, "What is X?"), + (TG_DOCUMENT, "urn:doc/finding"), + ]) + entity = Finding.from_triples("urn:f", triples) + assert entity.goal == "What is X?" 
+ assert entity.document == "urn:doc/finding" + + def test_entity_type_field(self): + triples = _make_triples("urn:f", [TG_FINDING]) + entity = Finding.from_triples("urn:f", triples) + assert entity.entity_type == "finding" + + +class TestPlanParsing: + + def test_parses_steps(self): + triples = _make_triples("urn:p", [TG_PLAN_TYPE], [ + (TG_PLAN_STEP, "Define X"), + (TG_PLAN_STEP, "Research Y"), + (TG_PLAN_STEP, "Analyse Z"), + ]) + entity = Plan.from_triples("urn:p", triples) + assert set(entity.steps) == {"Define X", "Research Y", "Analyse Z"} + + def test_entity_type_field(self): + triples = _make_triples("urn:p", [TG_PLAN_TYPE]) + entity = Plan.from_triples("urn:p", triples) + assert entity.entity_type == "plan" + + +class TestStepResultParsing: + + def test_parses_step_and_document(self): + triples = _make_triples("urn:sr", [TG_STEP_RESULT, TG_ANSWER_TYPE], [ + (TG_PLAN_STEP, "Define X"), + (TG_DOCUMENT, "urn:doc/step"), + ]) + entity = StepResult.from_triples("urn:sr", triples) + assert entity.step == "Define X" + assert entity.document == "urn:doc/step" + + def test_entity_type_field(self): + triples = _make_triples("urn:sr", [TG_STEP_RESULT]) + entity = StepResult.from_triples("urn:sr", triples) + assert entity.entity_type == "step-result" diff --git a/tests/unit/test_agent/test_meta_router.py b/tests/unit/test_agent/test_meta_router.py new file mode 100644 index 00000000..da0c634c --- /dev/null +++ b/tests/unit/test_agent/test_meta_router.py @@ -0,0 +1,289 @@ +""" +Unit tests for the MetaRouter — task type identification and pattern selection. 
+""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.agent.orchestrator.meta_router import ( + MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE, +) + + +def _make_config(patterns=None, task_types=None): + """Build a config dict as the config service would provide.""" + config = {} + if patterns: + config["agent-pattern"] = { + pid: json.dumps(pdata) for pid, pdata in patterns.items() + } + if task_types: + config["agent-task-type"] = { + tid: json.dumps(tdata) for tid, tdata in task_types.items() + } + return config + + +def _make_context(prompt_response): + """Build a mock context that returns a mock prompt client.""" + client = AsyncMock() + client.prompt = AsyncMock(return_value=prompt_response) + + def context(service_name): + return client + + return context + + +SAMPLE_PATTERNS = { + "react": {"name": "react", "description": "ReAct pattern"}, + "plan-then-execute": {"name": "plan-then-execute", "description": "Plan pattern"}, + "supervisor": {"name": "supervisor", "description": "Supervisor pattern"}, +} + +SAMPLE_TASK_TYPES = { + "general": { + "name": "general", + "description": "General queries", + "valid_patterns": ["react", "plan-then-execute", "supervisor"], + "framing": "", + }, + "research": { + "name": "research", + "description": "Research queries", + "valid_patterns": ["react", "plan-then-execute"], + "framing": "Focus on gathering information.", + }, + "summarisation": { + "name": "summarisation", + "description": "Summarisation queries", + "valid_patterns": ["react"], + "framing": "Focus on concise synthesis.", + }, +} + + +class TestMetaRouterInit: + + def test_defaults_when_no_config(self): + router = MetaRouter() + assert "react" in router.patterns + assert "general" in router.task_types + + def test_loads_patterns_from_config(self): + config = _make_config(patterns=SAMPLE_PATTERNS) + router = MetaRouter(config=config) + assert set(router.patterns.keys()) == {"react", "plan-then-execute", 
"supervisor"} + + def test_loads_task_types_from_config(self): + config = _make_config(task_types=SAMPLE_TASK_TYPES) + router = MetaRouter(config=config) + assert set(router.task_types.keys()) == {"general", "research", "summarisation"} + + def test_handles_invalid_json_in_config(self): + config = { + "agent-pattern": {"react": "not valid json"}, + } + router = MetaRouter(config=config) + assert "react" in router.patterns + assert router.patterns["react"]["name"] == "react" + + +class TestIdentifyTaskType: + + @pytest.mark.asyncio + async def test_skips_llm_when_single_task_type(self): + router = MetaRouter() # Only "general" + context = _make_context("should not be called") + + task_type, framing = await router.identify_task_type( + "test question", context, + ) + + assert task_type == "general" + + @pytest.mark.asyncio + async def test_uses_llm_when_multiple_task_types(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + context = _make_context("research") + + task_type, framing = await router.identify_task_type( + "Research the topic", context, + ) + + assert task_type == "research" + assert framing == "Focus on gathering information." 
+
+    @pytest.mark.asyncio
+    async def test_handles_llm_returning_quoted_type(self):
+        config = _make_config(
+            patterns=SAMPLE_PATTERNS,
+            task_types=SAMPLE_TASK_TYPES,
+        )
+        router = MetaRouter(config=config)
+        context = _make_context('"summarisation"')
+
+        task_type, _ = await router.identify_task_type(
+            "Summarise this", context,
+        )
+
+        assert task_type == "summarisation"
+
+    @pytest.mark.asyncio
+    async def test_falls_back_on_unknown_type(self):
+        config = _make_config(
+            patterns=SAMPLE_PATTERNS,
+            task_types=SAMPLE_TASK_TYPES,
+        )
+        router = MetaRouter(config=config)
+        context = _make_context("nonexistent-type")
+
+        task_type, _ = await router.identify_task_type(
+            "test question", context,
+        )
+
+        assert task_type == DEFAULT_TASK_TYPE
+
+    @pytest.mark.asyncio
+    async def test_falls_back_on_llm_error(self):
+        config = _make_config(
+            patterns=SAMPLE_PATTERNS,
+            task_types=SAMPLE_TASK_TYPES,
+        )
+        router = MetaRouter(config=config)
+
+        client = AsyncMock()
+        client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
+        context = lambda name: client
+
+        task_type, _ = await router.identify_task_type(
+            "test question", context,
+        )
+
+        assert task_type == DEFAULT_TASK_TYPE
+
+
+class TestSelectPattern:
+
+    @pytest.mark.asyncio
+    async def test_skips_llm_when_single_valid_pattern(self):
+        config = _make_config(
+            patterns=SAMPLE_PATTERNS,
+            task_types=SAMPLE_TASK_TYPES,
+        )
+        router = MetaRouter(config=config)
+        context = _make_context("should not be called")
+
+        # summarisation only has ["react"]
+        pattern = await router.select_pattern(
+            "Summarise this", "summarisation", context,
+        )
+
+        assert pattern == "react"
+
+    @pytest.mark.asyncio
+    async def test_uses_llm_when_multiple_valid_patterns(self):
+        config = _make_config(
+            patterns=SAMPLE_PATTERNS,
+            task_types=SAMPLE_TASK_TYPES,
+        )
+        router = MetaRouter(config=config)
+        context = _make_context("plan-then-execute")
+
+        # research has ["react", "plan-then-execute"]
+        pattern = await router.select_pattern(
+            "Research this", "research", context,
+        )
+
+        assert pattern == "plan-then-execute"
+
+    @pytest.mark.asyncio
+    async def test_respects_valid_patterns_constraint(self):
+        config = _make_config(
+            patterns=SAMPLE_PATTERNS,
+            task_types=SAMPLE_TASK_TYPES,
+        )
+        router = MetaRouter(config=config)
+        # LLM returns supervisor, but research doesn't allow it
+        context = _make_context("supervisor")
+
+        pattern = await router.select_pattern(
+            "Research this", "research", context,
+        )
+
+        # Should fall back to first valid pattern
+        assert pattern == "react"
+
+    @pytest.mark.asyncio
+    async def test_falls_back_on_llm_error(self):
+        config = _make_config(
+            patterns=SAMPLE_PATTERNS,
+            task_types=SAMPLE_TASK_TYPES,
+        )
+        router = MetaRouter(config=config)
+
+        client = AsyncMock()
+        client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
+        context = lambda name: client
+
+        # general has ["react", "plan-then-execute", "supervisor"]
+        pattern = await router.select_pattern(
+            "test", "general", context,
+        )
+
+        # Falls back to first valid pattern
+        assert pattern == "react"
+
+    @pytest.mark.asyncio
+    async def test_falls_back_to_default_for_unknown_task_type(self):
+        config = _make_config(
+            patterns=SAMPLE_PATTERNS,
+            task_types=SAMPLE_TASK_TYPES,
+        )
+        router = MetaRouter(config=config)
+        context = _make_context("react")
+
+        # Unknown task type — valid_patterns falls back to all patterns
+        pattern = await router.select_pattern(
+            "test", "unknown-type", context,
+        )
+
+        assert pattern == "react"
+
+
+class TestRoute:
+
+    @pytest.mark.asyncio
+    async def test_full_routing_pipeline(self):
+        config = _make_config(
+            patterns=SAMPLE_PATTERNS,
+            task_types=SAMPLE_TASK_TYPES,
+        )
+        router = MetaRouter(config=config)
+
+        # Mock context where prompt returns different values per call
+        client = AsyncMock()
+        call_count = 0
+
+        async def mock_prompt(**kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                return "research"  # task type
+            return "plan-then-execute"  # pattern
+
+        client.prompt = mock_prompt
+        context = lambda name: client
+
+        pattern, task_type, framing = await router.route(
+            "Research the relationships", context,
+        )
+
+        assert task_type == "research"
+        assert pattern == "plan-then-execute"
+        assert framing == "Focus on gathering information."
diff --git a/tests/unit/test_agent/test_pattern_base_subagent.py b/tests/unit/test_agent/test_pattern_base_subagent.py
new file mode 100644
index 00000000..1523b592
--- /dev/null
+++ b/tests/unit/test_agent/test_pattern_base_subagent.py
@@ -0,0 +1,144 @@
+"""
+Unit tests for PatternBase subagent helpers — is_subagent() and
+emit_subagent_completion().
+"""
+
+import pytest
+from unittest.mock import MagicMock, AsyncMock
+from dataclasses import dataclass
+
+from trustgraph.schema import AgentRequest
+
+from trustgraph.agent.orchestrator.pattern_base import PatternBase
+
+
+@dataclass
+class MockProcessor:
+    """Minimal processor mock for PatternBase."""
+    pass
+
+
+def _make_request(**kwargs):
+    defaults = dict(
+        question="Test question",
+        user="testuser",
+        collection="default",
+    )
+    defaults.update(kwargs)
+    return AgentRequest(**defaults)
+
+
+def _make_pattern():
+    return PatternBase(MockProcessor())
+
+
+class TestIsSubagent:
+
+    def test_returns_true_when_correlation_id_set(self):
+        pattern = _make_pattern()
+        request = _make_request(correlation_id="corr-123")
+        assert pattern.is_subagent(request) is True
+
+    def test_returns_false_when_correlation_id_empty(self):
+        pattern = _make_pattern()
+        request = _make_request(correlation_id="")
+        assert pattern.is_subagent(request) is False
+
+    def test_returns_false_when_correlation_id_missing(self):
+        pattern = _make_pattern()
+        request = _make_request()
+        assert pattern.is_subagent(request) is False
+
+
+class TestEmitSubagentCompletion:
+
+    @pytest.mark.asyncio
+    async def test_calls_next_with_completion_request(self):
+        pattern = _make_pattern()
+        request = _make_request(
+            correlation_id="corr-123",
+            parent_session_id="parent-sess",
+            subagent_goal="What is X?",
+            expected_siblings=4,
+        )
+        next_fn = AsyncMock()
+
+        await pattern.emit_subagent_completion(
+            request, next_fn, "The answer is Y",
+        )
+
+        next_fn.assert_called_once()
+        completion_req = next_fn.call_args[0][0]
+        assert isinstance(completion_req, AgentRequest)
+
+    @pytest.mark.asyncio
+    async def test_completion_has_correct_step_type(self):
+        pattern = _make_pattern()
+        request = _make_request(
+            correlation_id="corr-123",
+            subagent_goal="What is X?",
+        )
+        next_fn = AsyncMock()
+
+        await pattern.emit_subagent_completion(
+            request, next_fn, "answer text",
+        )
+
+        completion_req = next_fn.call_args[0][0]
+        assert len(completion_req.history) == 1
+        step = completion_req.history[0]
+        assert step.step_type == "subagent-completion"
+
+    @pytest.mark.asyncio
+    async def test_completion_carries_answer_in_observation(self):
+        pattern = _make_pattern()
+        request = _make_request(
+            correlation_id="corr-123",
+            subagent_goal="What is X?",
+        )
+        next_fn = AsyncMock()
+
+        await pattern.emit_subagent_completion(
+            request, next_fn, "The answer is Y",
+        )
+
+        completion_req = next_fn.call_args[0][0]
+        step = completion_req.history[0]
+        assert step.observation == "The answer is Y"
+
+    @pytest.mark.asyncio
+    async def test_completion_preserves_correlation_fields(self):
+        pattern = _make_pattern()
+        request = _make_request(
+            correlation_id="corr-123",
+            parent_session_id="parent-sess",
+            subagent_goal="What is X?",
+            expected_siblings=4,
+        )
+        next_fn = AsyncMock()
+
+        await pattern.emit_subagent_completion(
+            request, next_fn, "answer",
+        )
+
+        completion_req = next_fn.call_args[0][0]
+        assert completion_req.correlation_id == "corr-123"
+        assert completion_req.parent_session_id == "parent-sess"
+        assert completion_req.subagent_goal == "What is X?"
+        assert completion_req.expected_siblings == 4
+
+    @pytest.mark.asyncio
+    async def test_completion_has_empty_pattern(self):
+        pattern = _make_pattern()
+        request = _make_request(
+            correlation_id="corr-123",
+            subagent_goal="goal",
+        )
+        next_fn = AsyncMock()
+
+        await pattern.emit_subagent_completion(
+            request, next_fn, "answer",
+        )
+
+        completion_req = next_fn.call_args[0][0]
+        assert completion_req.pattern == ""
diff --git a/tests/unit/test_agent/test_provenance_triples.py b/tests/unit/test_agent/test_provenance_triples.py
new file mode 100644
index 00000000..ed14d6ae
--- /dev/null
+++ b/tests/unit/test_agent/test_provenance_triples.py
@@ -0,0 +1,226 @@
+"""
+Unit tests for orchestrator provenance triple builders.
+"""
+
+import pytest
+
+from trustgraph.provenance import (
+    agent_decomposition_triples,
+    agent_finding_triples,
+    agent_plan_triples,
+    agent_step_result_triples,
+    agent_synthesis_triples,
+)
+
+from trustgraph.provenance.namespaces import (
+    RDF_TYPE, RDFS_LABEL,
+    PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
+    TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
+    TG_SYNTHESIS, TG_ANSWER_TYPE, TG_DOCUMENT,
+    TG_SUBAGENT_GOAL, TG_PLAN_STEP,
+)
+
+
+def _triple_set(triples):
+    """Convert triples to a set of (s_iri, p_iri, o_value) for easy assertion."""
+    result = set()
+    for t in triples:
+        s = t.s.iri
+        p = t.p.iri
+        o = t.o.iri if t.o.iri else t.o.value
+        result.add((s, p, o))
+    return result
+
+
+def _has_type(triples, uri, rdf_type):
+    """Check if a URI has a given rdf:type in the triples."""
+    return (uri, RDF_TYPE, rdf_type) in _triple_set(triples)
+
+
+def _get_values(triples, uri, predicate):
+    """Get all object values for a given subject + predicate."""
+    ts = _triple_set(triples)
+    return [o for s, p, o in ts if s == uri and p == predicate]
+
+
+class TestDecompositionTriples:
+
+    def test_has_correct_types(self):
+        triples = agent_decomposition_triples(
+            "urn:decompose", "urn:session", ["goal-a", "goal-b"],
+        )
+        assert _has_type(triples, "urn:decompose", PROV_ENTITY)
+        assert _has_type(triples, "urn:decompose", TG_DECOMPOSITION)
+
+    def test_not_answer_type(self):
+        triples = agent_decomposition_triples(
+            "urn:decompose", "urn:session", ["goal-a"],
+        )
+        assert not _has_type(triples, "urn:decompose", TG_ANSWER_TYPE)
+
+    def test_links_to_session(self):
+        triples = agent_decomposition_triples(
+            "urn:decompose", "urn:session", ["goal-a"],
+        )
+        ts = _triple_set(triples)
+        assert ("urn:decompose", PROV_WAS_GENERATED_BY, "urn:session") in ts
+
+    def test_includes_goals(self):
+        goals = ["What is X?", "What is Y?", "What is Z?"]
+        triples = agent_decomposition_triples(
+            "urn:decompose", "urn:session", goals,
+        )
+        values = _get_values(triples, "urn:decompose", TG_SUBAGENT_GOAL)
+        assert set(values) == set(goals)
+
+    def test_label_includes_count(self):
+        triples = agent_decomposition_triples(
+            "urn:decompose", "urn:session", ["a", "b", "c"],
+        )
+        labels = _get_values(triples, "urn:decompose", RDFS_LABEL)
+        assert any("3" in label for label in labels)
+
+
+class TestFindingTriples:
+
+    def test_has_correct_types(self):
+        triples = agent_finding_triples(
+            "urn:finding", "urn:decompose", "What is X?",
+        )
+        assert _has_type(triples, "urn:finding", PROV_ENTITY)
+        assert _has_type(triples, "urn:finding", TG_FINDING)
+        assert _has_type(triples, "urn:finding", TG_ANSWER_TYPE)
+
+    def test_links_to_decomposition(self):
+        triples = agent_finding_triples(
+            "urn:finding", "urn:decompose", "What is X?",
+        )
+        ts = _triple_set(triples)
+        assert ("urn:finding", PROV_WAS_DERIVED_FROM, "urn:decompose") in ts
+
+    def test_includes_goal(self):
+        triples = agent_finding_triples(
+            "urn:finding", "urn:decompose", "What is X?",
+        )
+        values = _get_values(triples, "urn:finding", TG_SUBAGENT_GOAL)
+        assert "What is X?" in values
+
+    def test_includes_document_when_provided(self):
+        triples = agent_finding_triples(
+            "urn:finding", "urn:decompose", "goal",
+            document_id="urn:doc/1",
+        )
+        values = _get_values(triples, "urn:finding", TG_DOCUMENT)
+        assert "urn:doc/1" in values
+
+    def test_no_document_when_none(self):
+        triples = agent_finding_triples(
+            "urn:finding", "urn:decompose", "goal",
+        )
+        values = _get_values(triples, "urn:finding", TG_DOCUMENT)
+        assert values == []
+
+
+class TestPlanTriples:
+
+    def test_has_correct_types(self):
+        triples = agent_plan_triples(
+            "urn:plan", "urn:session", ["step-a"],
+        )
+        assert _has_type(triples, "urn:plan", PROV_ENTITY)
+        assert _has_type(triples, "urn:plan", TG_PLAN_TYPE)
+
+    def test_not_answer_type(self):
+        triples = agent_plan_triples(
+            "urn:plan", "urn:session", ["step-a"],
+        )
+        assert not _has_type(triples, "urn:plan", TG_ANSWER_TYPE)
+
+    def test_links_to_session(self):
+        triples = agent_plan_triples(
+            "urn:plan", "urn:session", ["step-a"],
+        )
+        ts = _triple_set(triples)
+        assert ("urn:plan", PROV_WAS_GENERATED_BY, "urn:session") in ts
+
+    def test_includes_steps(self):
+        steps = ["Define X", "Research Y", "Analyse Z"]
+        triples = agent_plan_triples(
+            "urn:plan", "urn:session", steps,
+        )
+        values = _get_values(triples, "urn:plan", TG_PLAN_STEP)
+        assert set(values) == set(steps)
+
+    def test_label_includes_count(self):
+        triples = agent_plan_triples(
+            "urn:plan", "urn:session", ["a", "b"],
+        )
+        labels = _get_values(triples, "urn:plan", RDFS_LABEL)
+        assert any("2" in label for label in labels)
+
+
+class TestStepResultTriples:
+
+    def test_has_correct_types(self):
+        triples = agent_step_result_triples(
+            "urn:step", "urn:plan", "Define X",
+        )
+        assert _has_type(triples, "urn:step", PROV_ENTITY)
+        assert _has_type(triples, "urn:step", TG_STEP_RESULT)
+        assert _has_type(triples, "urn:step", TG_ANSWER_TYPE)
+
+    def test_links_to_plan(self):
+        triples = agent_step_result_triples(
+            "urn:step", "urn:plan", "Define X",
+        )
+        ts = _triple_set(triples)
+        assert ("urn:step", PROV_WAS_DERIVED_FROM, "urn:plan") in ts
+
+    def test_includes_goal(self):
+        triples = agent_step_result_triples(
+            "urn:step", "urn:plan", "Define X",
+        )
+        values = _get_values(triples, "urn:step", TG_PLAN_STEP)
+        assert "Define X" in values
+
+    def test_includes_document_when_provided(self):
+        triples = agent_step_result_triples(
+            "urn:step", "urn:plan", "goal",
+            document_id="urn:doc/step",
+        )
+        values = _get_values(triples, "urn:step", TG_DOCUMENT)
+        assert "urn:doc/step" in values
+
+
+class TestSynthesisTriples:
+
+    def test_has_correct_types(self):
+        triples = agent_synthesis_triples(
+            "urn:synthesis", "urn:previous",
+        )
+        assert _has_type(triples, "urn:synthesis", PROV_ENTITY)
+        assert _has_type(triples, "urn:synthesis", TG_SYNTHESIS)
+        assert _has_type(triples, "urn:synthesis", TG_ANSWER_TYPE)
+
+    def test_links_to_previous(self):
+        triples = agent_synthesis_triples(
+            "urn:synthesis", "urn:last-finding",
+        )
+        ts = _triple_set(triples)
+        assert ("urn:synthesis", PROV_WAS_DERIVED_FROM,
+                "urn:last-finding") in ts
+
+    def test_includes_document_when_provided(self):
+        triples = agent_synthesis_triples(
+            "urn:synthesis", "urn:previous",
+            document_id="urn:doc/synthesis",
+        )
+        values = _get_values(triples, "urn:synthesis", TG_DOCUMENT)
+        assert "urn:doc/synthesis" in values
+
+    def test_label_is_synthesis(self):
+        triples = agent_synthesis_triples(
+            "urn:synthesis", "urn:previous",
+        )
+        labels = _get_values(triples, "urn:synthesis", RDFS_LABEL)
+        assert "Synthesis" in labels