diff --git a/tests/unit/test_agent/test_orchestrator_provenance_integration.py b/tests/unit/test_agent/test_orchestrator_provenance_integration.py new file mode 100644 index 00000000..96d41259 --- /dev/null +++ b/tests/unit/test_agent/test_orchestrator_provenance_integration.py @@ -0,0 +1,655 @@ +""" +Integration tests for agent-orchestrator provenance chains. + +Tests all three patterns by calling iterate() with mocked dependencies +and verifying the explain events emitted via respond(). + +Provenance chains: + React: session → iteration → (observation or final) + Plan: session → plan → step-result(s) → synthesis + Supervisor: session → decomposition → finding(s) → synthesis +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from dataclasses import dataclass, field + +from trustgraph.schema import ( + AgentRequest, AgentResponse, AgentStep, PlanStep, +) + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, + GRAPH_RETRIEVAL, +) + +# Agent provenance type constants +from trustgraph.provenance.namespaces import ( + TG_AGENT_QUESTION, + TG_ANALYSIS, + TG_TOOL_USE, + TG_OBSERVATION_TYPE, + TG_CONCLUSION, + TG_DECOMPOSITION, + TG_FINDING, + TG_PLAN_TYPE, + TG_STEP_RESULT, + TG_SYNTHESIS as TG_AGENT_SYNTHESIS, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def has_type(triples, subject, rdf_type): + return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in triples + ) + + +def derived_from(triples, subject): + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return t.o.iri if t else None + + +def collect_explain_events(respond_mock): + """Extract explain events from a respond mock's call history.""" + events = [] + for call in respond_mock.call_args_list: + resp = call[0][0] + if isinstance(resp, AgentResponse) and resp.chunk_type == "explain": + events.append({ + "explain_id": resp.explain_id, + "explain_graph": resp.explain_graph, + "triples": resp.explain_triples, + }) + return events + + +# --------------------------------------------------------------------------- +# Mock processor +# --------------------------------------------------------------------------- + +def make_mock_processor(tools=None): + """Build a mock processor with the minimal interface patterns need.""" + processor = MagicMock() + processor.max_iterations = 10 + processor.save_answer_content = AsyncMock() + + # provenance_session_uri must return a real URI + def mock_session_uri(session_id): + return f"urn:trustgraph:agent:session:{session_id}" + processor.provenance_session_uri.side_effect = mock_session_uri + + # Agent with tools + agent = MagicMock() + agent.tools = tools or {} + agent.additional_context = "" + processor.agent = agent + + # Aggregator for supervisor + processor.aggregator = MagicMock() + + return processor + + +def make_mock_flow(): + """Build a mock flow that returns async mock producers.""" + producers = {} + + def flow_factory(name): + if name not in producers: + producers[name] = AsyncMock() + return producers[name] + + flow = MagicMock(side_effect=flow_factory) + flow._producers = producers + return flow + + +def make_base_request(**kwargs): + """Build a minimal AgentRequest.""" + defaults = dict( + question="What is quantum computing?", + state="", + group=[], + history=[], + user="testuser", + collection="default", + streaming=False, + session_id="test-session-123", + conversation_id="", + pattern="react", + task_type="", + framing="", + correlation_id="", + parent_session_id="", + subagent_goal="", + expected_siblings=0, + ) + defaults.update(kwargs) + return AgentRequest(**defaults) + + +# --------------------------------------------------------------------------- +# React pattern tests +# --------------------------------------------------------------------------- + +class TestReactPatternProvenance: + """ + React pattern chain: session → iteration → final + (single iteration ending in Final answer) + """ + + @pytest.mark.asyncio + async def test_single_iteration_final_answer(self): + """ + A single react iteration that produces a Final answer should emit: + session, iteration, final — in that order. + """ + from trustgraph.agent.orchestrator.react_pattern import ReactPattern + from trustgraph.agent.react.types import Action, Final + + processor = make_mock_processor() + pattern = ReactPattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + request = make_base_request() + + # Mock AgentManager.react to call on_action then return Final + with patch( + 'trustgraph.agent.orchestrator.react_pattern.AgentManager' + ) as MockAM: + mock_am = AsyncMock() + MockAM.return_value = mock_am + + final = Final( + thought="I know the answer", + final="Quantum computing uses qubits.", + ) + + async def mock_react(question, history, think, observe, answer, + context, streaming, on_action): + # Simulate the on_action callback before returning Final + if on_action: + await on_action(Action( + thought="I know the answer", + name="final", + arguments={}, + observation="", + )) + return final + + mock_am.react.side_effect = mock_react + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have 3 events: session, iteration, final + assert len(events) == 3, ( + f"Expected 3 explain events (session, iteration, final), " + f"got {len(events)}: {[e['explain_id'] for e in events]}" + ) + + # Check types + assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION) + assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS) + assert has_type(events[2]["triples"], events[2]["explain_id"], TG_CONCLUSION) + + # Check derivation chain + all_triples = [] + for e in events: + all_triples.extend(e["triples"]) + + uris = [e["explain_id"] for e in events] + + # iteration derives from session + assert derived_from(all_triples, uris[1]) == uris[0] + # final derives from session (first iteration, no prior observation) + assert derived_from(all_triples, uris[2]) == uris[0] + + @pytest.mark.asyncio + async def test_iteration_with_tool_call(self): + """ + A react iteration that calls a tool (not Final) should emit: + session, iteration, observation — then call next() for continuation. + """ + from trustgraph.agent.orchestrator.react_pattern import ReactPattern + from trustgraph.agent.react.types import Action + + # Create a mock tool + mock_tool = MagicMock() + mock_tool.name = "knowledge-query" + mock_tool.description = "Query the knowledge base" + mock_tool.arguments = [] + mock_tool.groups = [] + mock_tool.states = {} + mock_tool_impl = AsyncMock(return_value="The answer is 42") + mock_tool.implementation = MagicMock(return_value=mock_tool_impl) + + processor = make_mock_processor( + tools={"knowledge-query": mock_tool} + ) + pattern = ReactPattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + request = make_base_request() + + action = Action( + thought="I need to look this up", + name="knowledge-query", + arguments={"question": "What is quantum computing?"}, + observation="Quantum computing uses qubits.", + ) + + with patch( + 'trustgraph.agent.orchestrator.react_pattern.AgentManager' + ) as MockAM: + mock_am = AsyncMock() + MockAM.return_value = mock_am + + async def mock_react(question, history, think, observe, answer, + context, streaming, on_action): + if on_action: + await on_action(action) + return action + + mock_am.react.side_effect = mock_react + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have 3 events: session, iteration, observation + assert len(events) == 3, ( + f"Expected 3 explain events (session, iteration, observation), " + f"got {len(events)}: {[e['explain_id'] for e in events]}" + ) + + assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION) + assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS) + assert has_type(events[2]["triples"], events[2]["explain_id"], TG_OBSERVATION_TYPE) + + # next() should have been called to continue the loop + assert next_fn.called + + @pytest.mark.asyncio + async def test_all_triples_in_retrieval_graph(self): + """All explain triples should be in urn:graph:retrieval.""" + from trustgraph.agent.orchestrator.react_pattern import ReactPattern + from trustgraph.agent.react.types import Action, Final + + processor = make_mock_processor() + pattern = ReactPattern(processor) + respond = AsyncMock() + flow = make_mock_flow() + + with patch( + 'trustgraph.agent.orchestrator.react_pattern.AgentManager' + ) as MockAM: + mock_am = AsyncMock() + MockAM.return_value = mock_am + + async def mock_react(question, history, think, observe, answer, + context, streaming, on_action): + if on_action: + await on_action(Action( + thought="done", name="final", + arguments={}, observation="", + )) + return Final(thought="done", final="answer") + + mock_am.react.side_effect = mock_react + await pattern.iterate( + make_base_request(), respond, AsyncMock(), flow, + ) + + for event in collect_explain_events(respond): + for t in event["triples"]: + assert t.g == GRAPH_RETRIEVAL + + +# --------------------------------------------------------------------------- +# Plan-then-execute pattern tests +# --------------------------------------------------------------------------- + +class TestPlanPatternProvenance: + """ + Plan pattern chain: + Planning iteration: session → plan + Execution iterations: step-result(s) → synthesis + """ + + @pytest.mark.asyncio + async def test_planning_iteration_emits_session_and_plan(self): + """ + The first iteration (planning) should emit: + session, plan — then call next() with the plan in history. + """ + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + + processor = make_mock_processor() + pattern = PlanThenExecutePattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + # Mock prompt client for plan creation + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = [ + {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []}, + {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]}, + ] + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + request = make_base_request(pattern="plan") + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have 2 events: session, plan + assert len(events) == 2, ( + f"Expected 2 explain events (session, plan), " + f"got {len(events)}: {[e['explain_id'] for e in events]}" + ) + + assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION) + assert has_type(events[1]["triples"], events[1]["explain_id"], TG_PLAN_TYPE) + + # Plan should derive from session + all_triples = [] + for e in events: + all_triples.extend(e["triples"]) + assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"] + + # next() should have been called with plan in history + assert next_fn.called + + @pytest.mark.asyncio + async def test_execution_iteration_emits_step_result(self): + """ + An execution iteration should emit a step-result event. + """ + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + + # Create a mock tool + mock_tool = MagicMock() + mock_tool.name = "knowledge-query" + mock_tool.description = "Query KB" + mock_tool.arguments = [] + mock_tool.groups = [] + mock_tool.states = {} + mock_tool_impl = AsyncMock(return_value="Found the answer") + mock_tool.implementation = MagicMock(return_value=mock_tool_impl) + + processor = make_mock_processor( + tools={"knowledge-query": mock_tool} + ) + pattern = PlanThenExecutePattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + # Mock prompt for step execution + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = { + "tool": "knowledge-query", + "arguments": {"question": "quantum computing"}, + } + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + # Request with plan already in history (second iteration) + plan_step = AgentStep( + thought="Created plan", + action="plan", + arguments={}, + observation="[]", + step_type="plan", + plan=[ + PlanStep(goal="Find info", tool_hint="knowledge-query", + depends_on=[], status="pending", result=""), + ], + ) + request = make_base_request( + pattern="plan", + history=[plan_step], + ) + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have step-result (no session on iteration > 1) + step_events = [ + e for e in events + if has_type(e["triples"], e["explain_id"], TG_STEP_RESULT) + ] + assert len(step_events) == 1, ( + f"Expected 1 step-result event, got {len(step_events)}" + ) + + @pytest.mark.asyncio + async def test_synthesis_after_all_steps_complete(self): + """ + When all plan steps are completed, synthesis should be emitted. + """ + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + + processor = make_mock_processor() + pattern = PlanThenExecutePattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + # Mock prompt for synthesis + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = "The synthesised answer." + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + # Request with all steps completed + exec_step = AgentStep( + thought="Executing step", + action="knowledge-query", + arguments={}, + observation="Result", + step_type="execute", + plan=[ + PlanStep(goal="Find info", tool_hint="knowledge-query", + depends_on=[], status="completed", result="Found it"), + ], + ) + request = make_base_request( + pattern="plan", + history=[exec_step], + ) + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have synthesis event + synth_events = [ + e for e in events + if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS) + ] + assert len(synth_events) == 1, ( + f"Expected 1 synthesis event, got {len(synth_events)}" + ) + + +# --------------------------------------------------------------------------- +# Supervisor pattern tests +# --------------------------------------------------------------------------- + +class TestSupervisorPatternProvenance: + """ + Supervisor pattern chain: + Decompose: session → decomposition + (Fan-out to subagents happens externally) + Synthesise: synthesis (derives from findings) + """ + + @pytest.mark.asyncio + async def test_decompose_emits_session_and_decomposition(self): + """ + The decompose phase should emit: session, decomposition. + """ + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + processor = make_mock_processor() + pattern = SupervisorPattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + # Mock prompt for decomposition + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = [ + "What is quantum computing?", + "What are qubits?", + ] + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + request = make_base_request(pattern="supervisor") + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have 2 events: session, decomposition + assert len(events) == 2, ( + f"Expected 2 explain events (session, decomposition), " + f"got {len(events)}: {[e['explain_id'] for e in events]}" + ) + + assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION) + assert has_type(events[1]["triples"], events[1]["explain_id"], TG_DECOMPOSITION) + + # Decomposition derives from session + all_triples = [] + for e in events: + all_triples.extend(e["triples"]) + assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"] + + @pytest.mark.asyncio + async def test_synthesis_emits_after_subagent_results(self): + """ + When subagent results arrive, synthesis should be emitted. + """ + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + processor = make_mock_processor() + pattern = SupervisorPattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + # Mock prompt for synthesis + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = "The combined answer." + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + # Request with subagent results in history + synth_step = AgentStep( + thought="", + action="synthesise", + arguments={}, + observation="", + step_type="synthesise", + subagent_results={ + "What is quantum computing?": "It uses qubits", + "What are qubits?": "Quantum bits", + }, + ) + request = make_base_request( + pattern="supervisor", + history=[synth_step], + ) + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have synthesis event (no session on iteration > 1) + synth_events = [ + e for e in events + if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS) + ] + assert len(synth_events) == 1 + + @pytest.mark.asyncio + async def test_decompose_fans_out_subagents(self): + """The decompose phase should call next() for each subagent goal.""" + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + processor = make_mock_processor() + pattern = SupervisorPattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"] + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + request = make_base_request(pattern="supervisor") + + await pattern.iterate(request, respond, next_fn, flow) + + # 3 subagent requests fanned out + assert next_fn.call_count == 3 diff --git a/tests/unit/test_provenance/test_graph_rag_chain.py b/tests/unit/test_provenance/test_graph_rag_chain.py new file mode 100644 index 00000000..657384b0 --- /dev/null +++ b/tests/unit/test_provenance/test_graph_rag_chain.py @@ -0,0 +1,295 @@ +""" +Structural test for the graph-rag provenance chain. + +Verifies that a complete graph-rag query produces the expected +provenance chain: + + question → grounding → exploration → focus → synthesis + +Each step must: +- Have the correct rdf:type +- Link to its predecessor via prov:wasDerivedFrom +- Carry expected domain-specific data +""" + +import pytest + +from trustgraph.provenance.triples import ( + question_triples, + grounding_triples, + exploration_triples, + focus_triples, + synthesis_triples, +) +from trustgraph.provenance.uris import ( + question_uri, + grounding_uri, + exploration_uri, + focus_uri, + synthesis_uri, +) +from trustgraph.provenance.namespaces import ( + RDF_TYPE, RDFS_LABEL, + PROV_ENTITY, PROV_WAS_DERIVED_FROM, + TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_GRAPH_RAG_QUESTION, TG_ANSWER_TYPE, + TG_QUERY, TG_CONCEPT, TG_ENTITY, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_DOCUMENT, + PROV_STARTED_AT_TIME, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +SESSION_ID = "test-session-1234" + + +def find_triple(triples, predicate, subject=None): + """Find first triple matching predicate (and optionally subject).""" + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + """Find all triples matching predicate (and optionally subject).""" + return [ + t for t in triples + if t.p.iri == predicate + and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + """Check if subject has the given rdf:type.""" + return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in triples + ) + + +def derived_from(triples, subject): + """Get the wasDerivedFrom target URI for a subject.""" + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return t.o.iri if t else None + + +# --------------------------------------------------------------------------- +# Build the full chain +# --------------------------------------------------------------------------- + +@pytest.fixture +def chain(): + """Build all provenance triples for a complete graph-rag query.""" + q_uri = question_uri(SESSION_ID) + gnd_uri = grounding_uri(SESSION_ID) + exp_uri = exploration_uri(SESSION_ID) + foc_uri = focus_uri(SESSION_ID) + syn_uri = synthesis_uri(SESSION_ID) + + q = question_triples(q_uri, "What is quantum computing?", "2026-01-01T00:00:00Z") + gnd = grounding_triples(gnd_uri, q_uri, ["quantum", "computing"]) + exp = exploration_triples( + exp_uri, gnd_uri, edge_count=42, + entities=["urn:entity:1", "urn:entity:2"], + ) + foc = focus_triples( + foc_uri, exp_uri, + selected_edges_with_reasoning=[ + { + "edge": ( + "http://example.com/QuantumComputing", + "http://schema.org/relatedTo", + "http://example.com/Physics", + ), + "reasoning": "Directly relevant to the query", + }, + { + "edge": ( + "http://example.com/QuantumComputing", + "http://schema.org/name", + "Quantum Computing", + ), + "reasoning": "Provides the entity label", + }, + ], + session_id=SESSION_ID, + ) + syn = synthesis_triples(syn_uri, foc_uri, document_id="urn:doc:answer-1") + + return { + "uris": { + "question": q_uri, + "grounding": gnd_uri, + "exploration": exp_uri, + "focus": foc_uri, + "synthesis": syn_uri, + }, + "triples": { + "question": q, + "grounding": gnd, + "exploration": exp, + "focus": foc, + "synthesis": syn, + }, + "all": q + gnd + exp + foc + syn, + } + + +# --------------------------------------------------------------------------- +# Chain structure tests +# --------------------------------------------------------------------------- + +class TestGraphRagProvenanceChain: + """Verify the full question → grounding → exploration → focus → synthesis chain.""" + + def test_chain_has_five_stages(self, chain): + """Each stage should produce at least some triples.""" + for stage in ["question", "grounding", "exploration", "focus", "synthesis"]: + assert len(chain["triples"][stage]) > 0, f"{stage} produced no triples" + + def test_derivation_chain(self, chain): + """ + The wasDerivedFrom links must form: + grounding → question, exploration → grounding, + focus → exploration, synthesis → focus. + """ + uris = chain["uris"] + all_triples = chain["all"] + + assert derived_from(all_triples, uris["grounding"]) == uris["question"] + assert derived_from(all_triples, uris["exploration"]) == uris["grounding"] + assert derived_from(all_triples, uris["focus"]) == uris["exploration"] + assert derived_from(all_triples, uris["synthesis"]) == uris["focus"] + + def test_question_has_no_parent(self, chain): + """The root question should not derive from anything (no parent_uri).""" + uris = chain["uris"] + all_triples = chain["all"] + assert derived_from(all_triples, uris["question"]) is None + + def test_question_with_parent(self): + """When a parent_uri is given, question should derive from it.""" + q_uri = question_uri("child-session") + parent = "urn:trustgraph:agent:iteration:parent" + q = question_triples(q_uri, "sub-query", "2026-01-01T00:00:00Z", + parent_uri=parent) + assert derived_from(q, q_uri) == parent + + +# --------------------------------------------------------------------------- +# Type annotation tests +# --------------------------------------------------------------------------- + +class TestGraphRagProvenanceTypes: + """Each stage must have the correct rdf:type annotations.""" + + def test_question_types(self, chain): + uris = chain["uris"] + triples = chain["triples"]["question"] + assert has_type(triples, uris["question"], PROV_ENTITY) + assert has_type(triples, uris["question"], TG_GRAPH_RAG_QUESTION) + + def test_grounding_types(self, chain): + uris = chain["uris"] + triples = chain["triples"]["grounding"] + assert has_type(triples, uris["grounding"], PROV_ENTITY) + assert has_type(triples, uris["grounding"], TG_GROUNDING) + + def test_exploration_types(self, chain): + uris = chain["uris"] + triples = chain["triples"]["exploration"] + assert has_type(triples, uris["exploration"], PROV_ENTITY) + assert has_type(triples, uris["exploration"], TG_EXPLORATION) + + def test_focus_types(self, chain): + uris = chain["uris"] + triples = chain["triples"]["focus"] + assert has_type(triples, uris["focus"], PROV_ENTITY) + assert has_type(triples, uris["focus"], TG_FOCUS) + + def test_synthesis_types(self, chain): + uris = chain["uris"] + triples = chain["triples"]["synthesis"] + assert has_type(triples, uris["synthesis"], PROV_ENTITY) + assert has_type(triples, uris["synthesis"], TG_SYNTHESIS) + assert has_type(triples, uris["synthesis"], TG_ANSWER_TYPE) + + +# --------------------------------------------------------------------------- +# Domain-specific content tests +# --------------------------------------------------------------------------- + +class TestGraphRagProvenanceContent: + """Each stage should carry the expected domain data.""" + + def test_question_has_query_text(self, chain): + uris = chain["uris"] + t = find_triple(chain["triples"]["question"], TG_QUERY, uris["question"]) + assert t is not None + assert t.o.value == "What is quantum computing?" + + def test_question_has_timestamp(self, chain): + uris = chain["uris"] + t = find_triple(chain["triples"]["question"], PROV_STARTED_AT_TIME, uris["question"]) + assert t is not None + assert t.o.value == "2026-01-01T00:00:00Z" + + def test_grounding_has_concepts(self, chain): + uris = chain["uris"] + concepts = find_triples(chain["triples"]["grounding"], TG_CONCEPT, uris["grounding"]) + concept_values = {t.o.value for t in concepts} + assert concept_values == {"quantum", "computing"} + + def test_exploration_has_edge_count(self, chain): + uris = chain["uris"] + t = find_triple(chain["triples"]["exploration"], TG_EDGE_COUNT, uris["exploration"]) + assert t is not None + assert t.o.value == "42" + + def test_exploration_has_entities(self, chain): + uris = chain["uris"] + entities = find_triples(chain["triples"]["exploration"], TG_ENTITY, uris["exploration"]) + entity_iris = {t.o.iri for t in entities} + assert entity_iris == {"urn:entity:1", "urn:entity:2"} + + def test_focus_has_selected_edges(self, chain): + uris = chain["uris"] + edges = find_triples(chain["triples"]["focus"], TG_SELECTED_EDGE, uris["focus"]) + assert len(edges) == 2 + + def test_focus_edges_have_quoted_triples(self, chain): + """Each edge selection entity should have a tg:edge with a quoted triple.""" + focus = chain["triples"]["focus"] + edge_triples = find_triples(focus, TG_EDGE) + assert len(edge_triples) == 2 + + # Each should have a quoted triple as the object + for t in edge_triples: + assert t.o.triple is not None, "tg:edge object should be a quoted triple" + + def test_focus_edges_have_reasoning(self, chain): + """Each edge selection entity should have tg:reasoning.""" + focus = chain["triples"]["focus"] + reasoning = find_triples(focus, TG_REASONING) + assert len(reasoning) == 2 + reasoning_texts = {t.o.value for t in reasoning} + assert "Directly relevant to the query" in reasoning_texts + assert "Provides the entity label" in reasoning_texts + + def test_synthesis_has_document_ref(self, chain): + uris = chain["uris"] + t = find_triple(chain["triples"]["synthesis"], TG_DOCUMENT, uris["synthesis"]) + assert t is not None + assert t.o.iri == "urn:doc:answer-1" + + def test_synthesis_has_labels(self, chain): + uris = chain["uris"] + t = find_triple(chain["triples"]["synthesis"], RDFS_LABEL, uris["synthesis"]) + assert t is not None + assert t.o.value == "Synthesis" diff --git a/tests/unit/test_retrieval/test_document_rag_provenance_integration.py b/tests/unit/test_retrieval/test_document_rag_provenance_integration.py new file mode 100644 index 00000000..74157285 --- /dev/null +++ b/tests/unit/test_retrieval/test_document_rag_provenance_integration.py @@ -0,0 +1,380 @@ +""" +Integration test: run a full DocumentRag.query() with mocked subsidiary +clients and verify the explain_callback receives the complete provenance +chain in the correct order with correct structure. + +Document-RAG provenance chain (4 stages): + question → grounding → exploration → synthesis +""" + +import pytest +from unittest.mock import AsyncMock +from dataclasses import dataclass + +from trustgraph.retrieval.document_rag.document_rag import DocumentRag + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, + TG_SYNTHESIS, TG_ANSWER_TYPE, + TG_QUERY, TG_CONCEPT, + TG_CHUNK_COUNT, TG_SELECTED_CHUNK, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + return [ + t for t in triples + if t.p.iri == predicate + and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in triples + ) + + +def derived_from(triples, subject): + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return t.o.iri if t else None + + +@dataclass +class ChunkMatch: + """Mimics the result from doc_embeddings_client.query().""" + chunk_id: str + + +# --------------------------------------------------------------------------- +# Mock setup +# --------------------------------------------------------------------------- + +CHUNK_A = "urn:chunk:policy-doc-1:chunk-0" +CHUNK_B = "urn:chunk:policy-doc-1:chunk-1" +CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase." +CHUNK_B_CONTENT = "Refunds are processed to the original payment method." + + +def build_mock_clients(): + """ + Build mock clients for a document-rag query. + + Client call sequence during query(): + 1. prompt_client.prompt("extract-concepts", ...) -> concepts + 2. embeddings_client.embed(concepts) -> vectors + 3. doc_embeddings_client.query(vector, ...) -> chunk matches + 4. fetch_chunk(chunk_id, user) -> chunk content + 5. prompt_client.document_prompt(query, documents) -> answer + """ + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + doc_embeddings_client = AsyncMock() + fetch_chunk = AsyncMock() + + # 1. Concept extraction + async def mock_prompt(template_id, variables=None, **kwargs): + if template_id == "extract-concepts": + return "return policy\nrefund" + return "" + + prompt_client.prompt.side_effect = mock_prompt + + # 2. Embedding vectors + embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + + # 3. Chunk matching + doc_embeddings_client.query.return_value = [ + ChunkMatch(chunk_id=CHUNK_A), + ChunkMatch(chunk_id=CHUNK_B), + ] + + # 4. Chunk content + async def mock_fetch(chunk_id, user): + return { + CHUNK_A: CHUNK_A_CONTENT, + CHUNK_B: CHUNK_B_CONTENT, + }[chunk_id] + + fetch_chunk.side_effect = mock_fetch + + # 5. Synthesis + prompt_client.document_prompt.return_value = ( + "Items can be returned within 30 days for a full refund." + ) + + return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestDocumentRagQueryProvenance: + """ + Run a real DocumentRag.query() and verify the provenance chain emitted + via explain_callback. + """ + + @pytest.mark.asyncio + async def test_explain_callback_receives_four_events(self): + """query() should emit exactly 4 explain events.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + assert len(events) == 4, ( + f"Expected 4 explain events (question, grounding, exploration, " + f"synthesis), got {len(events)}" + ) + + @pytest.mark.asyncio + async def test_events_have_correct_types_in_order(self): + """ + Events should arrive as: + question, grounding, exploration, synthesis. + """ + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + expected_types = [ + TG_DOC_RAG_QUESTION, + TG_GROUNDING, + TG_EXPLORATION, + TG_SYNTHESIS, + ] + + for i, expected_type in enumerate(expected_types): + uri = events[i]["explain_id"] + triples = events[i]["triples"] + assert has_type(triples, uri, expected_type), ( + f"Event {i} (uri={uri}) should have type {expected_type}" + ) + + @pytest.mark.asyncio + async def test_derivation_chain_links_correctly(self): + """ + Each event's URI should link to the previous via wasDerivedFrom: + question → (none) + grounding → question + exploration → grounding + synthesis → exploration + """ + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + uris = [e["explain_id"] for e in events] + all_triples = [] + for e in events: + all_triples.extend(e["triples"]) + + # question has no parent + assert derived_from(all_triples, uris[0]) is None + + # grounding → question + assert derived_from(all_triples, uris[1]) == uris[0] + + # exploration → grounding + assert derived_from(all_triples, uris[2]) == uris[1] + + # synthesis → exploration + assert derived_from(all_triples, uris[3]) == uris[2] + + @pytest.mark.asyncio + async def test_question_carries_query_text(self): + """The question event should contain the original query string.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + q_uri = events[0]["explain_id"] + q_triples = events[0]["triples"] + t = find_triple(q_triples, TG_QUERY, q_uri) + assert t is not None + assert t.o.value == "What is the return policy?" + + @pytest.mark.asyncio + async def test_grounding_carries_concepts(self): + """The grounding event should list extracted concepts.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + gnd_uri = events[1]["explain_id"] + gnd_triples = events[1]["triples"] + concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri) + concept_values = {t.o.value for t in concepts} + assert "return policy" in concept_values + assert "refund" in concept_values + + @pytest.mark.asyncio + async def test_exploration_has_chunk_count(self): + """The exploration event should report the number of chunks retrieved.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + exp_uri = events[2]["explain_id"] + exp_triples = events[2]["triples"] + t = find_triple(exp_triples, TG_CHUNK_COUNT, exp_uri) + assert t is not None + assert int(t.o.value) == 2 + + @pytest.mark.asyncio + async def test_exploration_has_selected_chunks(self): + """The exploration event should list the chunk IDs that were fetched.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + exp_uri = events[2]["explain_id"] + exp_triples = events[2]["triples"] + chunks = find_triples(exp_triples, TG_SELECTED_CHUNK, exp_uri) + chunk_iris = {t.o.iri for t in chunks} + assert CHUNK_A in chunk_iris + assert CHUNK_B in chunk_iris + + @pytest.mark.asyncio + async def test_synthesis_is_answer_type(self): + """The synthesis event should have tg:Synthesis and tg:Answer types.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + syn_uri = events[3]["explain_id"] + syn_triples = events[3]["triples"] + assert has_type(syn_triples, syn_uri, TG_SYNTHESIS) + assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE) + + @pytest.mark.asyncio + async def test_query_returns_answer_text(self): + """query() should return the synthesised answer.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + result = await rag.query( + query="What is the return policy?", + explain_callback=AsyncMock(), + ) + + assert result == "Items can be returned within 30 days for a full refund." + + @pytest.mark.asyncio + async def test_no_explain_callback_still_works(self): + """query() without explain_callback should return answer normally.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + result = await rag.query(query="What is the return policy?") + assert result == "Items can be returned within 30 days for a full refund." + + @pytest.mark.asyncio + async def test_all_triples_in_retrieval_graph(self): + """All emitted triples should be in the urn:graph:retrieval graph.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + for event in events: + for t in event["triples"]: + assert t.g == "urn:graph:retrieval", ( + f"Triple {t.s.iri} {t.p.iri} should be in " + f"urn:graph:retrieval, got {t.g}" + ) diff --git a/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py b/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py new file mode 100644 index 00000000..603bd204 --- /dev/null +++ b/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py @@ -0,0 +1,358 @@ +""" +Tests that explain_triples are forwarded correctly through the graph-rag +service and client layers. + +Covers: +- Service: explain messages include triples from the provenance callback +- Client: explain_callback receives explain_triples from the response +- End-to-end: triples survive the full service → client → callback chain +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.schema import ( + GraphRagQuery, GraphRagResponse, + Triple, Term, IRI, LITERAL, +) +from trustgraph.base.graph_rag_client import GraphRagClient + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def make_triple(s_iri, p_iri, o_value, o_type=IRI): + """Create a Triple with IRI subject/predicate and typed object.""" + o = ( + Term(type=IRI, iri=o_value) if o_type == IRI + else Term(type=LITERAL, value=o_value) + ) + return Triple( + s=Term(type=IRI, iri=s_iri), + p=Term(type=IRI, iri=p_iri), + o=o, + ) + + +def sample_focus_triples(): + """Focus-style triples with a quoted triple (edge selection).""" + return [ + make_triple( + "urn:trustgraph:focus:abc", + "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", + "https://trustgraph.ai/ns/Focus", + ), + make_triple( + "urn:trustgraph:focus:abc", + "http://www.w3.org/ns/prov#wasDerivedFrom", + "urn:trustgraph:exploration:abc", + ), + make_triple( + "urn:trustgraph:focus:abc", + "https://trustgraph.ai/ns/selectedEdge", + "urn:trustgraph:edge-sel:abc:0", + ), + ] + + +def sample_question_triples(): + """Question-style triples.""" + return [ + make_triple( + "urn:trustgraph:question:abc", + "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", + "https://trustgraph.ai/ns/GraphRagQuestion", + ), + make_triple( + "urn:trustgraph:question:abc", + "https://trustgraph.ai/ns/query", + "What is quantum computing?", + o_type=LITERAL, + ), + ] + + +# --------------------------------------------------------------------------- +# Service-level: explain messages carry triples +# --------------------------------------------------------------------------- + +class TestGraphRagServiceExplainTriples: + """Test that the graph-rag service includes explain_triples in messages.""" + + @patch('trustgraph.retrieval.graph_rag.rag.GraphRag') + @pytest.mark.asyncio + async def test_explain_messages_include_triples(self, mock_graph_rag_class): + """ + When the provenance callback is invoked with triples, the service + should include them in the explain response message. + """ + from trustgraph.retrieval.graph_rag.rag import Processor + + processor = Processor( + taskgroup=MagicMock(), + id="test-processor", + entity_limit=50, + triple_limit=30, + max_subgraph_size=150, + max_path_length=2, + ) + + mock_rag_instance = AsyncMock() + mock_graph_rag_class.return_value = mock_rag_instance + + question_triples = sample_question_triples() + focus_triples = sample_focus_triples() + + async def mock_query(**kwargs): + explain_callback = kwargs.get('explain_callback') + if explain_callback: + await explain_callback( + question_triples, "urn:trustgraph:question:abc" + ) + await explain_callback( + focus_triples, "urn:trustgraph:focus:abc" + ) + return "The answer." + + mock_rag_instance.query.side_effect = mock_query + + msg = MagicMock() + msg.value.return_value = GraphRagQuery( + query="What is quantum computing?", + user="trustgraph", + collection="default", + streaming=False, + ) + msg.properties.return_value = {"id": "test-id"} + + consumer = MagicMock() + flow = MagicMock() + mock_response = AsyncMock() + mock_provenance = AsyncMock() + + def flow_router(name): + if name == "response": + return mock_response + if name == "explainability": + return mock_provenance + return AsyncMock() + + flow.side_effect = flow_router + + await processor.on_request(msg, consumer, flow) + + # Find the explain messages + explain_msgs = [ + call[0][0] + for call in mock_response.send.call_args_list + if call[0][0].message_type == "explain" + ] + + assert len(explain_msgs) == 2 + + # First explain message should carry question triples + assert explain_msgs[0].explain_id == "urn:trustgraph:question:abc" + assert explain_msgs[0].explain_triples == question_triples + + # Second explain message should carry focus triples + assert explain_msgs[1].explain_id == "urn:trustgraph:focus:abc" + assert explain_msgs[1].explain_triples == focus_triples + + +# --------------------------------------------------------------------------- +# Client-level: explain_callback receives triples +# --------------------------------------------------------------------------- + +class TestGraphRagClientExplainForwarding: + """Test that GraphRagClient.rag() forwards explain_triples to callback.""" + + @pytest.mark.asyncio + async def test_explain_callback_receives_triples(self): + """ + The explain_callback should receive (explain_id, explain_graph, + explain_triples) — not just (explain_id, explain_graph). + """ + focus_triples = sample_focus_triples() + + # Simulate the response sequence the client would receive + responses = [ + GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:focus:abc", + explain_graph="urn:graph:retrieval", + explain_triples=focus_triples, + ), + GraphRagResponse( + message_type="chunk", + response="The answer.", + end_of_stream=True, + ), + GraphRagResponse( + message_type="chunk", + response="", + end_of_session=True, + ), + ] + + # Capture what the explain_callback receives + received_calls = [] + + async def explain_callback(explain_id, explain_graph, explain_triples): + received_calls.append({ + "explain_id": explain_id, + "explain_graph": explain_graph, + "explain_triples": explain_triples, + }) + + # Patch self.request to feed responses to the recipient + client = GraphRagClient.__new__(GraphRagClient) + + async def mock_request(req, timeout=600, recipient=None): + for resp in responses: + done = await recipient(resp) + if done: + return resp + + client.request = mock_request + + result = await client.rag( + query="test", + explain_callback=explain_callback, + ) + + assert result == "The answer." + assert len(received_calls) == 1 + assert received_calls[0]["explain_id"] == "urn:trustgraph:focus:abc" + assert received_calls[0]["explain_graph"] == "urn:graph:retrieval" + assert received_calls[0]["explain_triples"] == focus_triples + + @pytest.mark.asyncio + async def test_explain_callback_receives_empty_triples(self): + """ + When an explain event has no triples, the callback should still + receive an empty list (not None or missing). + """ + responses = [ + GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:question:abc", + explain_graph="urn:graph:retrieval", + explain_triples=[], + ), + GraphRagResponse( + message_type="chunk", + response="Answer.", + end_of_stream=True, + end_of_session=True, + ), + ] + + received_calls = [] + + async def explain_callback(explain_id, explain_graph, explain_triples): + received_calls.append(explain_triples) + + client = GraphRagClient.__new__(GraphRagClient) + + async def mock_request(req, timeout=600, recipient=None): + for resp in responses: + done = await recipient(resp) + if done: + return resp + + client.request = mock_request + + await client.rag(query="test", explain_callback=explain_callback) + + assert len(received_calls) == 1 + assert received_calls[0] == [] + + @pytest.mark.asyncio + async def test_multiple_explain_events_all_forward_triples(self): + """ + Each explain event in a session should forward its own triples. + """ + q_triples = sample_question_triples() + f_triples = sample_focus_triples() + + responses = [ + GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:question:abc", + explain_graph="urn:graph:retrieval", + explain_triples=q_triples, + ), + GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:focus:abc", + explain_graph="urn:graph:retrieval", + explain_triples=f_triples, + ), + GraphRagResponse( + message_type="chunk", + response="Answer.", + end_of_stream=True, + end_of_session=True, + ), + ] + + received_calls = [] + + async def explain_callback(explain_id, explain_graph, explain_triples): + received_calls.append({ + "explain_id": explain_id, + "explain_triples": explain_triples, + }) + + client = GraphRagClient.__new__(GraphRagClient) + + async def mock_request(req, timeout=600, recipient=None): + for resp in responses: + done = await recipient(resp) + if done: + return resp + + client.request = mock_request + + await client.rag(query="test", explain_callback=explain_callback) + + assert len(received_calls) == 2 + assert received_calls[0]["explain_id"] == "urn:trustgraph:question:abc" + assert received_calls[0]["explain_triples"] == q_triples + assert received_calls[1]["explain_id"] == "urn:trustgraph:focus:abc" + assert received_calls[1]["explain_triples"] == f_triples + + @pytest.mark.asyncio + async def test_no_explain_callback_does_not_error(self): + """ + When no explain_callback is provided, explain events should be + silently skipped without errors. + """ + responses = [ + GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:question:abc", + explain_graph="urn:graph:retrieval", + explain_triples=sample_question_triples(), + ), + GraphRagResponse( + message_type="chunk", + response="Answer.", + end_of_stream=True, + end_of_session=True, + ), + ] + + client = GraphRagClient.__new__(GraphRagClient) + + async def mock_request(req, timeout=600, recipient=None): + for resp in responses: + done = await recipient(resp) + if done: + return resp + + client.request = mock_request + + result = await client.rag(query="test") + assert result == "Answer." diff --git a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py new file mode 100644 index 00000000..36536f7d --- /dev/null +++ b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py @@ -0,0 +1,482 @@ +""" +Integration test: run a full GraphRag.query() with mocked subsidiary clients +and verify the explain_callback receives the complete provenance chain +in the correct order with correct structure. + +This tests the real query() method end-to-end, not just the triple builders. +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock +from dataclasses import dataclass + +from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id +from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, + TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, + TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE, + TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT, + TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + return [ + t for t in triples + if t.p.iri == predicate + and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in triples + ) + + +def derived_from(triples, subject): + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return t.o.iri if t else None + + +@dataclass +class EmbeddingMatch: + """Mimics the result from graph_embeddings_client.query().""" + entity: Term + + +# --------------------------------------------------------------------------- +# Mock setup +# --------------------------------------------------------------------------- + +# A tiny knowledge graph: 2 entities, 3 edges +ENTITY_A = "http://example.com/QuantumComputing" +ENTITY_B = "http://example.com/Physics" +EDGE_1 = (ENTITY_A, "http://schema.org/relatedTo", ENTITY_B) +EDGE_2 = (ENTITY_A, "http://schema.org/name", "Quantum Computing") +EDGE_3 = (ENTITY_B, "http://schema.org/name", "Physics") + + +def make_schema_triple(s, p, o): + """Create a SchemaTriple from string values.""" + return SchemaTriple( + s=Term(type=IRI, iri=s), + p=Term(type=IRI, iri=p), + o=Term(type=IRI, iri=o) if o.startswith("http") else Term(type=LITERAL, value=o), + ) + + +def build_mock_clients(): + """ + Build mock clients that simulate a small knowledge graph query. + + Client call sequence during query(): + 1. prompt_client.prompt("extract-concepts", ...) -> concepts + 2. embeddings_client.embed(concepts) -> vectors + 3. graph_embeddings_client.query(vector, ...) -> entity matches + 4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch) + 5. triples_client.query(s, LABEL, ...) -> labels (maybe_label) + 6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges + 7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning + 8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns []) + 9. prompt_client.prompt("kg-synthesis", ...) -> final answer + """ + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + graph_embeddings_client = AsyncMock() + triples_client = AsyncMock() + + # 1. Concept extraction + prompt_responses = {} + prompt_responses["extract-concepts"] = "quantum computing\nphysics" + + # 2. Embedding vectors (simple fake vectors) + embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + + # 3. Entity lookup - return our two entities + graph_embeddings_client.query.return_value = [ + EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_A)), + EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)), + ] + + # 4. Triple queries (follow_edges_batch) - return our edges + kg_triples = [ + make_schema_triple(*EDGE_1), + make_schema_triple(*EDGE_2), + make_schema_triple(*EDGE_3), + ] + triples_client.query_stream.return_value = kg_triples + + # 5. Label resolution - return entity as its own label (simplify) + async def mock_label_query(s=None, p=None, o=None, limit=1, + user=None, collection=None, g=None): + return [] # No labels found, will fall back to URI + triples_client.query.side_effect = mock_label_query + + # 6+7. Edge scoring and reasoning: dynamically score/reason about + # whatever edges the query method sends us, since edge IDs are computed + # from str(Term) representations which include the full dataclass repr. + synthesis_answer = "Quantum computing applies physics principles to computation." + + async def mock_prompt(template_id, variables=None, **kwargs): + if template_id == "extract-concepts": + return prompt_responses["extract-concepts"] + elif template_id == "kg-edge-scoring": + # Score all edges highly, using the IDs that GraphRag computed + edges = variables.get("knowledge", []) + return [ + {"id": e["id"], "score": 10 - i} + for i, e in enumerate(edges) + ] + elif template_id == "kg-edge-reasoning": + # Provide reasoning for each edge + edges = variables.get("knowledge", []) + return [ + {"id": e["id"], "reasoning": f"Relevant edge {i}"} + for i, e in enumerate(edges) + ] + elif template_id == "kg-synthesis": + return synthesis_answer + return "" + + prompt_client.prompt.side_effect = mock_prompt + + return prompt_client, embeddings_client, graph_embeddings_client, triples_client + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestGraphRagQueryProvenance: + """ + Run a real GraphRag.query() and verify the provenance chain emitted + via explain_callback. + """ + + @pytest.mark.asyncio + async def test_explain_callback_receives_five_events(self): + """query() should emit exactly 5 explain events.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, # skip semantic pre-filter for simplicity + ) + + assert len(events) == 5, ( + f"Expected 5 explain events (question, grounding, exploration, " + f"focus, synthesis), got {len(events)}" + ) + + @pytest.mark.asyncio + async def test_events_have_correct_types_in_order(self): + """ + Events should arrive as: + question, grounding, exploration, focus, synthesis. + """ + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + expected_types = [ + TG_GRAPH_RAG_QUESTION, + TG_GROUNDING, + TG_EXPLORATION, + TG_FOCUS, + TG_SYNTHESIS, + ] + + for i, expected_type in enumerate(expected_types): + uri = events[i]["explain_id"] + triples = events[i]["triples"] + assert has_type(triples, uri, expected_type), ( + f"Event {i} (uri={uri}) should have type {expected_type}" + ) + + @pytest.mark.asyncio + async def test_derivation_chain_links_correctly(self): + """ + Each event's URI should link to the previous via wasDerivedFrom: + grounding → question → (none) + exploration → grounding + focus → exploration + synthesis → focus + """ + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + uris = [e["explain_id"] for e in events] + all_triples = [] + for e in events: + all_triples.extend(e["triples"]) + + # question has no parent + assert derived_from(all_triples, uris[0]) is None + + # grounding → question + assert derived_from(all_triples, uris[1]) == uris[0] + + # exploration → grounding + assert derived_from(all_triples, uris[2]) == uris[1] + + # focus → exploration + assert derived_from(all_triples, uris[3]) == uris[2] + + # synthesis → focus + assert derived_from(all_triples, uris[4]) == uris[3] + + @pytest.mark.asyncio + async def test_question_event_carries_query_text(self): + """The question event should contain the original query string.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + q_uri = events[0]["explain_id"] + q_triples = events[0]["triples"] + t = find_triple(q_triples, TG_QUERY, q_uri) + assert t is not None + assert t.o.value == "What is quantum computing?" + + @pytest.mark.asyncio + async def test_grounding_carries_concepts(self): + """The grounding event should list extracted concepts.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + gnd_uri = events[1]["explain_id"] + gnd_triples = events[1]["triples"] + concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri) + concept_values = {t.o.value for t in concepts} + assert "quantum computing" in concept_values + assert "physics" in concept_values + + @pytest.mark.asyncio + async def test_exploration_has_edge_count(self): + """The exploration event should report how many edges were found.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + exp_uri = events[2]["explain_id"] + exp_triples = events[2]["triples"] + t = find_triple(exp_triples, TG_EDGE_COUNT, exp_uri) + assert t is not None + # Should be non-zero (we provided 3 edges, label edges filtered) + assert int(t.o.value) > 0 + + @pytest.mark.asyncio + async def test_focus_has_selected_edges_with_reasoning(self): + """ + The focus event should carry selected edges as quoted triples + with reasoning text. + """ + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + foc_uri = events[3]["explain_id"] + foc_triples = events[3]["triples"] + + # Should have selected edges + selected = find_triples(foc_triples, TG_SELECTED_EDGE, foc_uri) + assert len(selected) > 0, "Focus should have at least one selected edge" + + # Each edge selection should have a quoted triple + edge_t = find_triples(foc_triples, TG_EDGE) + assert len(edge_t) > 0, "Focus should have tg:edge with quoted triples" + for t in edge_t: + assert t.o.triple is not None, "tg:edge object must be a quoted triple" + + # Should have reasoning + reasoning = find_triples(foc_triples, TG_REASONING) + assert len(reasoning) > 0, "Focus should have reasoning for selected edges" + reasoning_texts = {t.o.value for t in reasoning} + assert any(r for r in reasoning_texts), "Reasoning should not be empty" + + @pytest.mark.asyncio + async def test_synthesis_is_answer_type(self): + """The synthesis event should have tg:Answer type.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + syn_uri = events[4]["explain_id"] + syn_triples = events[4]["triples"] + assert has_type(syn_triples, syn_uri, TG_SYNTHESIS) + assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE) + + @pytest.mark.asyncio + async def test_query_returns_answer_text(self): + """query() should still return the synthesised answer.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + result = await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + assert result == "Quantum computing applies physics principles to computation." + + @pytest.mark.asyncio + async def test_parent_uri_links_question_to_parent(self): + """When parent_uri is provided, question should derive from it.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + parent = "urn:trustgraph:agent:iteration:xyz" + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + parent_uri=parent, + ) + + q_uri = events[0]["explain_id"] + q_triples = events[0]["triples"] + assert derived_from(q_triples, q_uri) == parent + + @pytest.mark.asyncio + async def test_no_explain_callback_still_works(self): + """query() without explain_callback should return answer normally.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + result = await rag.query( + query="What is quantum computing?", + edge_score_limit=0, + ) + + assert result == "Quantum computing applies physics principles to computation." + + @pytest.mark.asyncio + async def test_all_triples_in_retrieval_graph(self): + """All emitted triples should be in the urn:graph:retrieval graph.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + for event in events: + for t in event["triples"]: + assert t.g == "urn:graph:retrieval", ( + f"Triple {t.s.iri} {t.p.iri} should be in " + f"urn:graph:retrieval, got {t.g}" + ) diff --git a/trustgraph-base/trustgraph/base/graph_rag_client.py b/trustgraph-base/trustgraph/base/graph_rag_client.py index 32007943..9db23293 100644 --- a/trustgraph-base/trustgraph/base/graph_rag_client.py +++ b/trustgraph-base/trustgraph/base/graph_rag_client.py @@ -15,7 +15,7 @@ class GraphRagClient(RequestResponse): user: User identifier collection: Collection identifier chunk_callback: Optional async callback(text, end_of_stream) for text chunks - explain_callback: Optional async callback(explain_id, explain_graph) for explain notifications + explain_callback: Optional async callback(explain_id, explain_graph, explain_triples) for explain notifications timeout: Request timeout in seconds Returns: @@ -30,7 +30,7 @@ class GraphRagClient(RequestResponse): # Handle explain notifications if resp.message_type == 'explain': if explain_callback and resp.explain_id: - await explain_callback(resp.explain_id, resp.explain_graph) + await explain_callback(resp.explain_id, resp.explain_graph, resp.explain_triples) return False # Continue receiving # Handle text chunks diff --git a/trustgraph-base/trustgraph/clients/document_rag_client.py b/trustgraph-base/trustgraph/clients/document_rag_client.py index 057376fb..365ea09d 100644 --- a/trustgraph-base/trustgraph/clients/document_rag_client.py +++ b/trustgraph-base/trustgraph/clients/document_rag_client.py @@ -43,7 +43,7 @@ class DocumentRagClient(BaseClient): user: User identifier collection: Collection identifier chunk_callback: Optional callback(text, end_of_stream) for text chunks - explain_callback: Optional callback(explain_id, explain_graph) for explain notifications + explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications timeout: Request timeout in seconds Returns: @@ -55,7 +55,7 @@ class DocumentRagClient(BaseClient): # Handle explain notifications (response is None/empty, explain_id present) if x.explain_id and not x.response: if explain_callback: - explain_callback(x.explain_id, x.explain_graph) + explain_callback(x.explain_id, x.explain_graph, x.explain_triples) return False # Continue receiving # Handle text chunks diff --git a/trustgraph-base/trustgraph/clients/graph_rag_client.py b/trustgraph-base/trustgraph/clients/graph_rag_client.py index 17d7b0f0..0d33bf91 100644 --- a/trustgraph-base/trustgraph/clients/graph_rag_client.py +++ b/trustgraph-base/trustgraph/clients/graph_rag_client.py @@ -47,7 +47,7 @@ class GraphRagClient(BaseClient): user: User identifier collection: Collection identifier chunk_callback: Optional callback(text, end_of_stream) for text chunks - explain_callback: Optional callback(explain_id, explain_graph) for explain notifications + explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications timeout: Request timeout in seconds Returns: @@ -59,7 +59,7 @@ class GraphRagClient(BaseClient): # Handle explain notifications if x.message_type == 'explain': if explain_callback and x.explain_id: - explain_callback(x.explain_id, x.explain_graph) + explain_callback(x.explain_id, x.explain_graph, x.explain_triples) return False # Continue receiving # Handle text chunks diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 041558ec..6fd96ade 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -39,13 +39,14 @@ class KnowledgeQueryImpl: if respond: from ... schema import AgentResponse - async def explain_callback(explain_id, explain_graph): + async def explain_callback(explain_id, explain_graph, explain_triples=None): self.context.last_sub_explain_uri = explain_id await respond(AgentResponse( chunk_type="explain", content="", explain_id=explain_id, explain_graph=explain_graph, + explain_triples=explain_triples or [], )) if current_uri: