Forward missing explain_triples through RAG clients and agent tool callback (#768)

fix: forward explain_triples through RAG clients and agent tool callback
- RAG clients and the KnowledgeQueryImpl tool callback were
  dropping explain_triples from explain events, losing provenance
  data (including focus edge selections) when graph-rag is invoked
  via the agent.

Tests for provenance and explainability (56 new):
- Client-level forwarding of explain_triples
- Graph-RAG structural chain
  (question → grounding → exploration → focus → synthesis)
- Graph-RAG integration with mocked subsidiary clients
- Document-RAG integration
  (question → grounding → exploration → synthesis)
- Agent-orchestrator all 3 patterns: react, plan-then-execute,
  supervisor
This commit is contained in:
cybermaggedon 2026-04-08 11:41:17 +01:00 committed by GitHub
parent e899370d98
commit 4b5bfacab1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 2178 additions and 7 deletions

View file

@ -0,0 +1,655 @@
"""
Integration tests for agent-orchestrator provenance chains.
Tests all three patterns by calling iterate() with mocked dependencies
and verifying the explain events emitted via respond().
Provenance chains:
React: session → iteration → (observation or final)
Plan: session → plan → step-result(s) → synthesis
Supervisor: session → decomposition → finding(s) → synthesis
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import dataclass, field
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
GRAPH_RETRIEVAL,
)
# Agent provenance type constants
from trustgraph.provenance.namespaces import (
TG_AGENT_QUESTION,
TG_ANALYSIS,
TG_TOOL_USE,
TG_OBSERVATION_TYPE,
TG_CONCLUSION,
TG_DECOMPOSITION,
TG_FINDING,
TG_PLAN_TYPE,
TG_STEP_RESULT,
TG_SYNTHESIS as TG_AGENT_SYNTHESIS,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
    """Return the first triple matching *predicate* (and, optionally, *subject*), else None."""
    matches = (
        t for t in triples
        if t.p.iri == predicate and (subject is None or t.s.iri == subject)
    )
    return next(matches, None)
def has_type(triples, subject, rdf_type):
    """Return True if *subject* carries rdf:type *rdf_type* anywhere in *triples*."""
    for t in triples:
        if t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type:
            return True
    return False
def derived_from(triples, subject):
    """Return the prov:wasDerivedFrom target IRI for *subject*, or None if absent."""
    link = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    return None if link is None else link.o.iri
def collect_explain_events(respond_mock):
    """Extract explain events from a respond mock's call history.

    Walks every recorded call, keeps only AgentResponse objects whose
    chunk_type is "explain", and returns them as dicts carrying the
    explain id, graph, and triples.
    """
    collected = []
    for call in respond_mock.call_args_list:
        response = call[0][0]
        if not isinstance(response, AgentResponse):
            continue
        if response.chunk_type != "explain":
            continue
        collected.append({
            "explain_id": response.explain_id,
            "explain_graph": response.explain_graph,
            "triples": response.explain_triples,
        })
    return collected
# ---------------------------------------------------------------------------
# Mock processor
# ---------------------------------------------------------------------------
def make_mock_processor(tools=None):
    """Build a mock processor exposing the minimal interface the patterns need.

    :param tools: optional mapping of tool name -> tool mock, attached to
        the processor's agent; defaults to an empty registry.
    """
    proc = MagicMock()
    proc.max_iterations = 10
    proc.save_answer_content = AsyncMock()
    # provenance_session_uri must return a real URI string
    proc.provenance_session_uri.side_effect = (
        lambda session_id: f"urn:trustgraph:agent:session:{session_id}"
    )
    # Agent carrying the tool registry
    mock_agent = MagicMock()
    mock_agent.tools = tools or {}
    mock_agent.additional_context = ""
    proc.agent = mock_agent
    # Aggregator used by the supervisor pattern
    proc.aggregator = MagicMock()
    return proc
def make_mock_flow():
    """Build a mock flow that memoises one AsyncMock producer per name."""
    registry = {}

    def lookup(name):
        # Create the producer on first access, reuse it afterwards.
        if name not in registry:
            registry[name] = AsyncMock()
        return registry[name]

    mock_flow = MagicMock(side_effect=lookup)
    # Expose the registry so tests can inspect producers directly.
    mock_flow._producers = registry
    return mock_flow
def make_base_request(**overrides):
    """Build a minimal AgentRequest; keyword overrides replace the defaults."""
    params = {
        "question": "What is quantum computing?",
        "state": "",
        "group": [],
        "history": [],
        "user": "testuser",
        "collection": "default",
        "streaming": False,
        "session_id": "test-session-123",
        "conversation_id": "",
        "pattern": "react",
        "task_type": "",
        "framing": "",
        "correlation_id": "",
        "parent_session_id": "",
        "subagent_goal": "",
        "expected_siblings": 0,
    }
    params.update(overrides)
    return AgentRequest(**params)
# ---------------------------------------------------------------------------
# React pattern tests
# ---------------------------------------------------------------------------
class TestReactPatternProvenance:
    """
    React pattern chain: session → iteration → final
    (single iteration ending in Final answer)
    """

    @pytest.mark.asyncio
    async def test_single_iteration_final_answer(self):
        """
        A single react iteration that produces a Final answer should emit:
        session, iteration, final in that order.
        """
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action, Final

        processor = make_mock_processor()
        pattern = ReactPattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()
        request = make_base_request()

        # Mock AgentManager.react to call on_action then return Final
        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am
            final = Final(
                thought="I know the answer",
                final="Quantum computing uses qubits.",
            )

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                # Simulate the on_action callback before returning Final
                if on_action:
                    await on_action(Action(
                        thought="I know the answer",
                        name="final",
                        arguments={},
                        observation="",
                    ))
                return final

            mock_am.react.side_effect = mock_react
            # iterate() must run while AgentManager is still patched.
            await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 3 events: session, iteration, final
        assert len(events) == 3, (
            f"Expected 3 explain events (session, iteration, final), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        # Check types
        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
        assert has_type(events[2]["triples"], events[2]["explain_id"], TG_CONCLUSION)

        # Check derivation chain
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])
        uris = [e["explain_id"] for e in events]
        # iteration derives from session
        assert derived_from(all_triples, uris[1]) == uris[0]
        # final derives from session (first iteration, no prior observation)
        assert derived_from(all_triples, uris[2]) == uris[0]

    @pytest.mark.asyncio
    async def test_iteration_with_tool_call(self):
        """
        A react iteration that calls a tool (not Final) should emit:
        session, iteration, observation then call next() for continuation.
        """
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action

        # Create a mock tool
        mock_tool = MagicMock()
        mock_tool.name = "knowledge-query"
        mock_tool.description = "Query the knowledge base"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="The answer is 42")
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = make_mock_processor(
            tools={"knowledge-query": mock_tool}
        )
        pattern = ReactPattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()
        request = make_base_request()

        # An Action (tool call) rather than a Final ends the iteration
        # with an observation and a continuation.
        action = Action(
            thought="I need to look this up",
            name="knowledge-query",
            arguments={"question": "What is quantum computing?"},
            observation="Quantum computing uses qubits.",
        )

        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                if on_action:
                    await on_action(action)
                return action

            mock_am.react.side_effect = mock_react
            await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 3 events: session, iteration, observation
        assert len(events) == 3, (
            f"Expected 3 explain events (session, iteration, observation), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )
        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
        assert has_type(events[2]["triples"], events[2]["explain_id"], TG_OBSERVATION_TYPE)

        # next() should have been called to continue the loop
        assert next_fn.called

    @pytest.mark.asyncio
    async def test_all_triples_in_retrieval_graph(self):
        """All explain triples should be in urn:graph:retrieval."""
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action, Final

        processor = make_mock_processor()
        pattern = ReactPattern(processor)
        respond = AsyncMock()
        flow = make_mock_flow()

        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                if on_action:
                    await on_action(Action(
                        thought="done", name="final",
                        arguments={}, observation="",
                    ))
                return Final(thought="done", final="answer")

            mock_am.react.side_effect = mock_react
            await pattern.iterate(
                make_base_request(), respond, AsyncMock(), flow,
            )

        # Every triple of every event must be placed in the retrieval graph.
        for event in collect_explain_events(respond):
            for t in event["triples"]:
                assert t.g == GRAPH_RETRIEVAL
# ---------------------------------------------------------------------------
# Plan-then-execute pattern tests
# ---------------------------------------------------------------------------
class TestPlanPatternProvenance:
    """
    Plan pattern chain:
    Planning iteration: session → plan
    Execution iterations: step-result(s) → synthesis
    """

    @pytest.mark.asyncio
    async def test_planning_iteration_emits_session_and_plan(self):
        """
        The first iteration (planning) should emit:
        session, plan then call next() with the plan in history.
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        processor = make_mock_processor()
        pattern = PlanThenExecutePattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt client for plan creation
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = [
            {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
            {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
        ]

        # Route only the prompt-request producer to the prepared mock.
        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()

        flow.side_effect = flow_factory

        request = make_base_request(pattern="plan")
        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 2 events: session, plan
        assert len(events) == 2, (
            f"Expected 2 explain events (session, plan), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )
        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_PLAN_TYPE)

        # Plan should derive from session
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])
        assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"]

        # next() should have been called with plan in history
        assert next_fn.called

    @pytest.mark.asyncio
    async def test_execution_iteration_emits_step_result(self):
        """
        An execution iteration should emit a step-result event.
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        # Create a mock tool
        mock_tool = MagicMock()
        mock_tool.name = "knowledge-query"
        mock_tool.description = "Query KB"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="Found the answer")
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = make_mock_processor(
            tools={"knowledge-query": mock_tool}
        )
        pattern = PlanThenExecutePattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt for step execution
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = {
            "tool": "knowledge-query",
            "arguments": {"question": "quantum computing"},
        }

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()

        flow.side_effect = flow_factory

        # Request with plan already in history (second iteration)
        plan_step = AgentStep(
            thought="Created plan",
            action="plan",
            arguments={},
            observation="[]",
            step_type="plan",
            plan=[
                PlanStep(goal="Find info", tool_hint="knowledge-query",
                         depends_on=[], status="pending", result=""),
            ],
        )
        request = make_base_request(
            pattern="plan",
            history=[plan_step],
        )

        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have step-result (no session on iteration > 1)
        step_events = [
            e for e in events
            if has_type(e["triples"], e["explain_id"], TG_STEP_RESULT)
        ]
        assert len(step_events) == 1, (
            f"Expected 1 step-result event, got {len(step_events)}"
        )

    @pytest.mark.asyncio
    async def test_synthesis_after_all_steps_complete(self):
        """
        When all plan steps are completed, synthesis should be emitted.
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        processor = make_mock_processor()
        pattern = PlanThenExecutePattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt for synthesis
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = "The synthesised answer."

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()

        flow.side_effect = flow_factory

        # Request with all steps completed
        exec_step = AgentStep(
            thought="Executing step",
            action="knowledge-query",
            arguments={},
            observation="Result",
            step_type="execute",
            plan=[
                PlanStep(goal="Find info", tool_hint="knowledge-query",
                         depends_on=[], status="completed", result="Found it"),
            ],
        )
        request = make_base_request(
            pattern="plan",
            history=[exec_step],
        )

        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have synthesis event
        synth_events = [
            e for e in events
            if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS)
        ]
        assert len(synth_events) == 1, (
            f"Expected 1 synthesis event, got {len(synth_events)}"
        )
# ---------------------------------------------------------------------------
# Supervisor pattern tests
# ---------------------------------------------------------------------------
class TestSupervisorPatternProvenance:
    """
    Supervisor pattern chain:
    Decompose: session → decomposition
    (Fan-out to subagents happens externally)
    Synthesise: synthesis (derives from findings)
    """

    @pytest.mark.asyncio
    async def test_decompose_emits_session_and_decomposition(self):
        """
        The decompose phase should emit: session, decomposition.
        """
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        processor = make_mock_processor()
        pattern = SupervisorPattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt for decomposition
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = [
            "What is quantum computing?",
            "What are qubits?",
        ]

        # Route only the prompt-request producer to the prepared mock.
        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()

        flow.side_effect = flow_factory

        request = make_base_request(pattern="supervisor")
        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 2 events: session, decomposition
        assert len(events) == 2, (
            f"Expected 2 explain events (session, decomposition), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )
        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_DECOMPOSITION)

        # Decomposition derives from session
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])
        assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"]

    @pytest.mark.asyncio
    async def test_synthesis_emits_after_subagent_results(self):
        """
        When subagent results arrive, synthesis should be emitted.
        """
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        processor = make_mock_processor()
        pattern = SupervisorPattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt for synthesis
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = "The combined answer."

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()

        flow.side_effect = flow_factory

        # Request with subagent results in history
        synth_step = AgentStep(
            thought="",
            action="synthesise",
            arguments={},
            observation="",
            step_type="synthesise",
            subagent_results={
                "What is quantum computing?": "It uses qubits",
                "What are qubits?": "Quantum bits",
            },
        )
        request = make_base_request(
            pattern="supervisor",
            history=[synth_step],
        )

        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have synthesis event (no session on iteration > 1)
        synth_events = [
            e for e in events
            if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS)
        ]
        assert len(synth_events) == 1

    @pytest.mark.asyncio
    async def test_decompose_fans_out_subagents(self):
        """The decompose phase should call next() for each subagent goal."""
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        processor = make_mock_processor()
        pattern = SupervisorPattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()

        flow.side_effect = flow_factory

        request = make_base_request(pattern="supervisor")
        await pattern.iterate(request, respond, next_fn, flow)

        # 3 subagent requests fanned out
        assert next_fn.call_count == 3

View file

@ -0,0 +1,295 @@
"""
Structural test for the graph-rag provenance chain.
Verifies that a complete graph-rag query produces the expected
provenance chain:
question → grounding → exploration → focus → synthesis
Each step must:
- Have the correct rdf:type
- Link to its predecessor via prov:wasDerivedFrom
- Carry expected domain-specific data
"""
import pytest
from trustgraph.provenance.triples import (
question_triples,
grounding_triples,
exploration_triples,
focus_triples,
synthesis_triples,
)
from trustgraph.provenance.uris import (
question_uri,
grounding_uri,
exploration_uri,
focus_uri,
synthesis_uri,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_GRAPH_RAG_QUESTION, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT,
PROV_STARTED_AT_TIME,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
SESSION_ID = "test-session-1234"
def find_triple(triples, predicate, subject=None):
    """Find first triple matching predicate (and optionally subject)."""
    for candidate in triples:
        if candidate.p.iri != predicate:
            continue
        if subject is not None and candidate.s.iri != subject:
            continue
        return candidate
    return None
def find_triples(triples, predicate, subject=None):
    """Find all triples matching predicate (and optionally subject)."""
    selected = []
    for t in triples:
        if t.p.iri != predicate:
            continue
        if subject is not None and t.s.iri != subject:
            continue
        selected.append(t)
    return selected
def has_type(triples, subject, rdf_type):
    """Check if subject has the given rdf:type."""
    for t in triples:
        if (t.s.iri, t.p.iri, t.o.iri) == (subject, RDF_TYPE, rdf_type):
            return True
    return False
def derived_from(triples, subject):
    """Get the wasDerivedFrom target URI for a subject, or None."""
    link = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    if link is None:
        return None
    return link.o.iri
# ---------------------------------------------------------------------------
# Build the full chain
# ---------------------------------------------------------------------------
@pytest.fixture
def chain():
    """Build all provenance triples for a complete graph-rag query."""
    uris = {
        "question": question_uri(SESSION_ID),
        "grounding": grounding_uri(SESSION_ID),
        "exploration": exploration_uri(SESSION_ID),
        "focus": focus_uri(SESSION_ID),
        "synthesis": synthesis_uri(SESSION_ID),
    }

    # Stage 1: the root question (no parent URI).
    question = question_triples(
        uris["question"], "What is quantum computing?", "2026-01-01T00:00:00Z",
    )
    # Stage 2: concept grounding, derived from the question.
    grounding = grounding_triples(
        uris["grounding"], uris["question"], ["quantum", "computing"],
    )
    # Stage 3: graph exploration, derived from the grounding.
    exploration = exploration_triples(
        uris["exploration"], uris["grounding"], edge_count=42,
        entities=["urn:entity:1", "urn:entity:2"],
    )
    # Stage 4: focus — selected edges, each with its own reasoning.
    focus = focus_triples(
        uris["focus"], uris["exploration"],
        selected_edges_with_reasoning=[
            {
                "edge": (
                    "http://example.com/QuantumComputing",
                    "http://schema.org/relatedTo",
                    "http://example.com/Physics",
                ),
                "reasoning": "Directly relevant to the query",
            },
            {
                "edge": (
                    "http://example.com/QuantumComputing",
                    "http://schema.org/name",
                    "Quantum Computing",
                ),
                "reasoning": "Provides the entity label",
            },
        ],
        session_id=SESSION_ID,
    )
    # Stage 5: synthesis of the final answer document.
    synthesis = synthesis_triples(
        uris["synthesis"], uris["focus"], document_id="urn:doc:answer-1",
    )

    return {
        "uris": uris,
        "triples": {
            "question": question,
            "grounding": grounding,
            "exploration": exploration,
            "focus": focus,
            "synthesis": synthesis,
        },
        "all": question + grounding + exploration + focus + synthesis,
    }
# ---------------------------------------------------------------------------
# Chain structure tests
# ---------------------------------------------------------------------------
class TestGraphRagProvenanceChain:
    """Verify the full question → grounding → exploration → focus → synthesis chain."""

    def test_chain_has_five_stages(self, chain):
        """Each stage should produce at least some triples."""
        for stage in ["question", "grounding", "exploration", "focus", "synthesis"]:
            assert len(chain["triples"][stage]) > 0, f"{stage} produced no triples"

    def test_derivation_chain(self, chain):
        """
        The wasDerivedFrom links must form:
        grounding → question, exploration → grounding,
        focus → exploration, synthesis → focus.
        """
        uris = chain["uris"]
        all_triples = chain["all"]
        assert derived_from(all_triples, uris["grounding"]) == uris["question"]
        assert derived_from(all_triples, uris["exploration"]) == uris["grounding"]
        assert derived_from(all_triples, uris["focus"]) == uris["exploration"]
        assert derived_from(all_triples, uris["synthesis"]) == uris["focus"]

    def test_question_has_no_parent(self, chain):
        """The root question should not derive from anything (no parent_uri)."""
        uris = chain["uris"]
        all_triples = chain["all"]
        assert derived_from(all_triples, uris["question"]) is None

    def test_question_with_parent(self):
        """When a parent_uri is given, question should derive from it."""
        # e.g. a sub-query spawned from an agent iteration
        q_uri = question_uri("child-session")
        parent = "urn:trustgraph:agent:iteration:parent"
        q = question_triples(q_uri, "sub-query", "2026-01-01T00:00:00Z",
                             parent_uri=parent)
        assert derived_from(q, q_uri) == parent
# ---------------------------------------------------------------------------
# Type annotation tests
# ---------------------------------------------------------------------------
class TestGraphRagProvenanceTypes:
    """Each stage must have the correct rdf:type annotations."""

    def test_question_types(self, chain):
        """Question is typed both prov:Entity and the graph-rag question type."""
        uris = chain["uris"]
        triples = chain["triples"]["question"]
        assert has_type(triples, uris["question"], PROV_ENTITY)
        assert has_type(triples, uris["question"], TG_GRAPH_RAG_QUESTION)

    def test_grounding_types(self, chain):
        """Grounding is typed prov:Entity and tg grounding."""
        uris = chain["uris"]
        triples = chain["triples"]["grounding"]
        assert has_type(triples, uris["grounding"], PROV_ENTITY)
        assert has_type(triples, uris["grounding"], TG_GROUNDING)

    def test_exploration_types(self, chain):
        """Exploration is typed prov:Entity and tg exploration."""
        uris = chain["uris"]
        triples = chain["triples"]["exploration"]
        assert has_type(triples, uris["exploration"], PROV_ENTITY)
        assert has_type(triples, uris["exploration"], TG_EXPLORATION)

    def test_focus_types(self, chain):
        """Focus is typed prov:Entity and tg focus."""
        uris = chain["uris"]
        triples = chain["triples"]["focus"]
        assert has_type(triples, uris["focus"], PROV_ENTITY)
        assert has_type(triples, uris["focus"], TG_FOCUS)

    def test_synthesis_types(self, chain):
        """Synthesis carries three types: prov:Entity, synthesis, and answer."""
        uris = chain["uris"]
        triples = chain["triples"]["synthesis"]
        assert has_type(triples, uris["synthesis"], PROV_ENTITY)
        assert has_type(triples, uris["synthesis"], TG_SYNTHESIS)
        assert has_type(triples, uris["synthesis"], TG_ANSWER_TYPE)
# ---------------------------------------------------------------------------
# Domain-specific content tests
# ---------------------------------------------------------------------------
class TestGraphRagProvenanceContent:
    """Each stage should carry the expected domain data."""

    def test_question_has_query_text(self, chain):
        """Question carries the original query string as tg:query."""
        uris = chain["uris"]
        t = find_triple(chain["triples"]["question"], TG_QUERY, uris["question"])
        assert t is not None
        assert t.o.value == "What is quantum computing?"

    def test_question_has_timestamp(self, chain):
        """Question carries the start timestamp as prov:startedAtTime."""
        uris = chain["uris"]
        t = find_triple(chain["triples"]["question"], PROV_STARTED_AT_TIME, uris["question"])
        assert t is not None
        assert t.o.value == "2026-01-01T00:00:00Z"

    def test_grounding_has_concepts(self, chain):
        """Grounding carries one tg:concept triple per extracted concept."""
        uris = chain["uris"]
        concepts = find_triples(chain["triples"]["grounding"], TG_CONCEPT, uris["grounding"])
        concept_values = {t.o.value for t in concepts}
        assert concept_values == {"quantum", "computing"}

    def test_exploration_has_edge_count(self, chain):
        """Exploration records how many edges were traversed."""
        uris = chain["uris"]
        t = find_triple(chain["triples"]["exploration"], TG_EDGE_COUNT, uris["exploration"])
        assert t is not None
        # The count is serialised as a string literal.
        assert t.o.value == "42"

    def test_exploration_has_entities(self, chain):
        """Exploration links every grounded entity IRI."""
        uris = chain["uris"]
        entities = find_triples(chain["triples"]["exploration"], TG_ENTITY, uris["exploration"])
        entity_iris = {t.o.iri for t in entities}
        assert entity_iris == {"urn:entity:1", "urn:entity:2"}

    def test_focus_has_selected_edges(self, chain):
        """Focus links one tg:selectedEdge per edge selection."""
        uris = chain["uris"]
        edges = find_triples(chain["triples"]["focus"], TG_SELECTED_EDGE, uris["focus"])
        assert len(edges) == 2

    def test_focus_edges_have_quoted_triples(self, chain):
        """Each edge selection entity should have a tg:edge with a quoted triple."""
        focus = chain["triples"]["focus"]
        edge_triples = find_triples(focus, TG_EDGE)
        assert len(edge_triples) == 2
        # Each should have a quoted triple as the object
        for t in edge_triples:
            assert t.o.triple is not None, "tg:edge object should be a quoted triple"

    def test_focus_edges_have_reasoning(self, chain):
        """Each edge selection entity should have tg:reasoning."""
        focus = chain["triples"]["focus"]
        reasoning = find_triples(focus, TG_REASONING)
        assert len(reasoning) == 2
        reasoning_texts = {t.o.value for t in reasoning}
        assert "Directly relevant to the query" in reasoning_texts
        assert "Provides the entity label" in reasoning_texts

    def test_synthesis_has_document_ref(self, chain):
        """Synthesis references the produced answer document by IRI."""
        uris = chain["uris"]
        t = find_triple(chain["triples"]["synthesis"], TG_DOCUMENT, uris["synthesis"])
        assert t is not None
        assert t.o.iri == "urn:doc:answer-1"

    def test_synthesis_has_labels(self, chain):
        """Synthesis carries a human-readable rdfs:label."""
        uris = chain["uris"]
        t = find_triple(chain["triples"]["synthesis"], RDFS_LABEL, uris["synthesis"])
        assert t is not None
        assert t.o.value == "Synthesis"

View file

@ -0,0 +1,380 @@
"""
Integration test: run a full DocumentRag.query() with mocked subsidiary
clients and verify the explain_callback receives the complete provenance
chain in the correct order with correct structure.
Document-RAG provenance chain (4 stages):
question → grounding → exploration → synthesis
"""
import pytest
from unittest.mock import AsyncMock
from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_SYNTHESIS, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT,
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
    """Return the first triple with this predicate (and subject, if given), else None."""
    return next(
        (
            t for t in triples
            if t.p.iri == predicate
            and (subject is None or t.s.iri == subject)
        ),
        None,
    )
def find_triples(triples, predicate, subject=None):
    """Return every triple with this predicate (and subject, if given)."""
    matched = []
    for t in triples:
        subject_ok = subject is None or t.s.iri == subject
        if t.p.iri == predicate and subject_ok:
            matched.append(t)
    return matched
def has_type(triples, subject, rdf_type):
    """Return True if *subject* is typed *rdf_type* in *triples*."""
    # Restrict to rdf:type statements first, then match subject/object.
    typed = (t for t in triples if t.p.iri == RDF_TYPE)
    return any(t.s.iri == subject and t.o.iri == rdf_type for t in typed)
def derived_from(triples, subject):
    """Return the prov:wasDerivedFrom target for *subject*, or None."""
    hit = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    return hit.o.iri if hit is not None else None
@dataclass
class ChunkMatch:
    """Mimics the result from doc_embeddings_client.query()."""
    chunk_id: str


# ---------------------------------------------------------------------------
# Mock setup
# ---------------------------------------------------------------------------

CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
CHUNK_B_CONTENT = "Refunds are processed to the original payment method."


def build_mock_clients():
    """
    Build mock clients for a document-rag query.

    Client call sequence during query():
        1. prompt_client.prompt("extract-concepts", ...) -> concepts
        2. embeddings_client.embed(concepts) -> vectors
        3. doc_embeddings_client.query(vector, ...) -> chunk matches
        4. fetch_chunk(chunk_id, user) -> chunk content
        5. prompt_client.document_prompt(query, documents) -> answer
    """
    prompt_client = AsyncMock()

    # 1. Concept extraction: only the extract-concepts template yields text.
    async def fake_prompt(template_id, variables=None, **kwargs):
        if template_id != "extract-concepts":
            return ""
        return "return policy\nrefund"

    prompt_client.prompt.side_effect = fake_prompt

    # 5. Synthesised answer for the final document prompt.
    prompt_client.document_prompt.return_value = (
        "Items can be returned within 30 days for a full refund."
    )

    # 2. One embedding vector per extracted concept.
    embeddings_client = AsyncMock()
    embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]

    # 3. Two chunk hits, in relevance order.
    doc_embeddings_client = AsyncMock()
    doc_embeddings_client.query.return_value = [
        ChunkMatch(chunk_id=CHUNK_A),
        ChunkMatch(chunk_id=CHUNK_B),
    ]

    # 4. Chunk id -> content lookup (KeyError on unknown ids, as before).
    chunk_contents = {
        CHUNK_A: CHUNK_A_CONTENT,
        CHUNK_B: CHUNK_B_CONTENT,
    }

    async def fake_fetch(chunk_id, user):
        return chunk_contents[chunk_id]

    fetch_chunk = AsyncMock(side_effect=fake_fetch)

    return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestDocumentRagQueryProvenance:
"""
Run a real DocumentRag.query() and verify the provenance chain emitted
via explain_callback.
"""
    @pytest.mark.asyncio
    async def test_explain_callback_receives_four_events(self):
        """query() should emit exactly 4 explain events."""
        clients = build_mock_clients()
        rag = DocumentRag(*clients)
        events = []

        # Capture every explain event emitted during the query.
        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        assert len(events) == 4, (
            f"Expected 4 explain events (question, grounding, exploration, "
            f"synthesis), got {len(events)}"
        )
    @pytest.mark.asyncio
    async def test_events_have_correct_types_in_order(self):
        """
        Events should arrive as:
        question, grounding, exploration, synthesis.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)
        events = []

        # Capture every explain event emitted during the query.
        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        # Expected rdf:type for each event, in emission order.
        expected_types = [
            TG_DOC_RAG_QUESTION,
            TG_GROUNDING,
            TG_EXPLORATION,
            TG_SYNTHESIS,
        ]
        for i, expected_type in enumerate(expected_types):
            uri = events[i]["explain_id"]
            triples = events[i]["triples"]
            assert has_type(triples, uri, expected_type), (
                f"Event {i} (uri={uri}) should have type {expected_type}"
            )
@pytest.mark.asyncio
async def test_derivation_chain_links_correctly(self):
"""
Each event's URI should link to the previous via wasDerivedFrom:
question (none)
grounding question
exploration grounding
synthesis exploration
"""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
uris = [e["explain_id"] for e in events]
all_triples = []
for e in events:
all_triples.extend(e["triples"])
# question has no parent
assert derived_from(all_triples, uris[0]) is None
# grounding → question
assert derived_from(all_triples, uris[1]) == uris[0]
# exploration → grounding
assert derived_from(all_triples, uris[2]) == uris[1]
# synthesis → exploration
assert derived_from(all_triples, uris[3]) == uris[2]
@pytest.mark.asyncio
async def test_question_carries_query_text(self):
"""The question event should contain the original query string."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
q_uri = events[0]["explain_id"]
q_triples = events[0]["triples"]
t = find_triple(q_triples, TG_QUERY, q_uri)
assert t is not None
assert t.o.value == "What is the return policy?"
@pytest.mark.asyncio
async def test_grounding_carries_concepts(self):
"""The grounding event should list extracted concepts."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
gnd_uri = events[1]["explain_id"]
gnd_triples = events[1]["triples"]
concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
concept_values = {t.o.value for t in concepts}
assert "return policy" in concept_values
assert "refund" in concept_values
@pytest.mark.asyncio
async def test_exploration_has_chunk_count(self):
"""The exploration event should report the number of chunks retrieved."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
exp_uri = events[2]["explain_id"]
exp_triples = events[2]["triples"]
t = find_triple(exp_triples, TG_CHUNK_COUNT, exp_uri)
assert t is not None
assert int(t.o.value) == 2
@pytest.mark.asyncio
async def test_exploration_has_selected_chunks(self):
"""The exploration event should list the chunk IDs that were fetched."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
exp_uri = events[2]["explain_id"]
exp_triples = events[2]["triples"]
chunks = find_triples(exp_triples, TG_SELECTED_CHUNK, exp_uri)
chunk_iris = {t.o.iri for t in chunks}
assert CHUNK_A in chunk_iris
assert CHUNK_B in chunk_iris
@pytest.mark.asyncio
async def test_synthesis_is_answer_type(self):
"""The synthesis event should have tg:Synthesis and tg:Answer types."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
syn_uri = events[3]["explain_id"]
syn_triples = events[3]["triples"]
assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)
@pytest.mark.asyncio
async def test_query_returns_answer_text(self):
"""query() should return the synthesised answer."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(
query="What is the return policy?",
explain_callback=AsyncMock(),
)
assert result == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_no_explain_callback_still_works(self):
"""query() without explain_callback should return answer normally."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(query="What is the return policy?")
assert result == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):
"""All emitted triples should be in the urn:graph:retrieval graph."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
for event in events:
for t in event["triples"]:
assert t.g == "urn:graph:retrieval", (
f"Triple {t.s.iri} {t.p.iri} should be in "
f"urn:graph:retrieval, got {t.g}"
)

View file

@ -0,0 +1,358 @@
"""
Tests that explain_triples are forwarded correctly through the graph-rag
service and client layers.
Covers:
- Service: explain messages include triples from the provenance callback
- Client: explain_callback receives explain_triples from the response
- End-to-end: triples survive the full service client callback chain
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.schema import (
GraphRagQuery, GraphRagResponse,
Triple, Term, IRI, LITERAL,
)
from trustgraph.base.graph_rag_client import GraphRagClient
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_triple(s_iri, p_iri, o_value, o_type=IRI):
    """Create a Triple with IRI subject/predicate and typed object."""
    if o_type == IRI:
        obj = Term(type=IRI, iri=o_value)
    else:
        obj = Term(type=LITERAL, value=o_value)
    subj = Term(type=IRI, iri=s_iri)
    pred = Term(type=IRI, iri=p_iri)
    return Triple(s=subj, p=pred, o=obj)
def sample_focus_triples():
    """Focus-style triples with a quoted triple (edge selection)."""
    focus = "urn:trustgraph:focus:abc"
    # (predicate, object) pairs all anchored on the focus URI
    statements = [
        ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
         "https://trustgraph.ai/ns/Focus"),
        ("http://www.w3.org/ns/prov#wasDerivedFrom",
         "urn:trustgraph:exploration:abc"),
        ("https://trustgraph.ai/ns/selectedEdge",
         "urn:trustgraph:edge-sel:abc:0"),
    ]
    return [make_triple(focus, p, o) for p, o in statements]
def sample_question_triples():
    """Question-style triples."""
    question = "urn:trustgraph:question:abc"
    type_triple = make_triple(
        question,
        "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
        "https://trustgraph.ai/ns/GraphRagQuestion",
    )
    # the query text is a literal, not an IRI
    query_triple = make_triple(
        question,
        "https://trustgraph.ai/ns/query",
        "What is quantum computing?",
        o_type=LITERAL,
    )
    return [type_triple, query_triple]
# ---------------------------------------------------------------------------
# Service-level: explain messages carry triples
# ---------------------------------------------------------------------------
class TestGraphRagServiceExplainTriples:
    """Test that the graph-rag service includes explain_triples in messages."""

    @patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
    @pytest.mark.asyncio
    async def test_explain_messages_include_triples(self, mock_graph_rag_class):
        """
        When the provenance callback is invoked with triples, the service
        should include them in the explain response message.
        """
        from trustgraph.retrieval.graph_rag.rag import Processor

        processor = Processor(
            taskgroup=MagicMock(),
            id="test-processor",
            entity_limit=50,
            triple_limit=30,
            max_subgraph_size=150,
            max_path_length=2,
        )

        rag_mock = AsyncMock()
        mock_graph_rag_class.return_value = rag_mock

        q_triples = sample_question_triples()
        f_triples = sample_focus_triples()

        # Simulated GraphRag.query(): fire two explain events through the
        # provenance callback, then return the answer text.
        async def fake_query(**kwargs):
            cb = kwargs.get('explain_callback')
            if cb:
                await cb(q_triples, "urn:trustgraph:question:abc")
                await cb(f_triples, "urn:trustgraph:focus:abc")
            return "The answer."

        rag_mock.query.side_effect = fake_query

        msg = MagicMock()
        msg.value.return_value = GraphRagQuery(
            query="What is quantum computing?",
            user="trustgraph",
            collection="default",
            streaming=False,
        )
        msg.properties.return_value = {"id": "test-id"}

        response_sink = AsyncMock()
        provenance_sink = AsyncMock()
        sinks = {
            "response": response_sink,
            "explainability": provenance_sink,
        }

        flow = MagicMock()
        # Unknown flow names get a throwaway AsyncMock, like the original router.
        flow.side_effect = lambda name: sinks.get(name, AsyncMock())

        await processor.on_request(msg, MagicMock(), flow)

        # Pick out only the explain-type messages that were sent
        explain_msgs = [
            call.args[0]
            for call in response_sink.send.call_args_list
            if call.args[0].message_type == "explain"
        ]
        assert len(explain_msgs) == 2
        # First explain message should carry question triples
        assert explain_msgs[0].explain_id == "urn:trustgraph:question:abc"
        assert explain_msgs[0].explain_triples == q_triples
        # Second explain message should carry focus triples
        assert explain_msgs[1].explain_id == "urn:trustgraph:focus:abc"
        assert explain_msgs[1].explain_triples == f_triples
# ---------------------------------------------------------------------------
# Client-level: explain_callback receives triples
# ---------------------------------------------------------------------------
class TestGraphRagClientExplainForwarding:
    """Test that GraphRagClient.rag() forwards explain_triples to callback."""

    @staticmethod
    def _client_replaying(responses):
        """Build a GraphRagClient whose request() replays *responses*.

        Each response is fed to the recipient in order, stopping when the
        recipient signals completion — mimicking the streaming transport
        without any real messaging.
        """
        client = GraphRagClient.__new__(GraphRagClient)

        async def mock_request(req, timeout=600, recipient=None):
            for resp in responses:
                if await recipient(resp):
                    return resp

        client.request = mock_request
        return client

    @staticmethod
    def _explain_response(explain_id, triples):
        """An explain-type response in the retrieval graph."""
        return GraphRagResponse(
            message_type="explain",
            explain_id=explain_id,
            explain_graph="urn:graph:retrieval",
            explain_triples=triples,
        )

    @pytest.mark.asyncio
    async def test_explain_callback_receives_triples(self):
        """
        The explain_callback should receive (explain_id, explain_graph,
        explain_triples) not just (explain_id, explain_graph).
        """
        focus_triples = sample_focus_triples()
        responses = [
            self._explain_response("urn:trustgraph:focus:abc", focus_triples),
            GraphRagResponse(
                message_type="chunk",
                response="The answer.",
                end_of_stream=True,
            ),
            GraphRagResponse(
                message_type="chunk",
                response="",
                end_of_session=True,
            ),
        ]
        received_calls = []

        async def explain_callback(explain_id, explain_graph, explain_triples):
            received_calls.append({
                "explain_id": explain_id,
                "explain_graph": explain_graph,
                "explain_triples": explain_triples,
            })

        client = self._client_replaying(responses)
        result = await client.rag(
            query="test",
            explain_callback=explain_callback,
        )
        assert result == "The answer."
        assert len(received_calls) == 1
        assert received_calls[0]["explain_id"] == "urn:trustgraph:focus:abc"
        assert received_calls[0]["explain_graph"] == "urn:graph:retrieval"
        assert received_calls[0]["explain_triples"] == focus_triples

    @pytest.mark.asyncio
    async def test_explain_callback_receives_empty_triples(self):
        """
        When an explain event has no triples, the callback should still
        receive an empty list (not None or missing).
        """
        responses = [
            self._explain_response("urn:trustgraph:question:abc", []),
            GraphRagResponse(
                message_type="chunk",
                response="Answer.",
                end_of_stream=True,
                end_of_session=True,
            ),
        ]
        received_calls = []

        async def explain_callback(explain_id, explain_graph, explain_triples):
            received_calls.append(explain_triples)

        client = self._client_replaying(responses)
        await client.rag(query="test", explain_callback=explain_callback)
        assert len(received_calls) == 1
        assert received_calls[0] == []

    @pytest.mark.asyncio
    async def test_multiple_explain_events_all_forward_triples(self):
        """
        Each explain event in a session should forward its own triples.
        """
        q_triples = sample_question_triples()
        f_triples = sample_focus_triples()
        responses = [
            self._explain_response("urn:trustgraph:question:abc", q_triples),
            self._explain_response("urn:trustgraph:focus:abc", f_triples),
            GraphRagResponse(
                message_type="chunk",
                response="Answer.",
                end_of_stream=True,
                end_of_session=True,
            ),
        ]
        received_calls = []

        async def explain_callback(explain_id, explain_graph, explain_triples):
            received_calls.append({
                "explain_id": explain_id,
                "explain_triples": explain_triples,
            })

        client = self._client_replaying(responses)
        await client.rag(query="test", explain_callback=explain_callback)
        assert len(received_calls) == 2
        assert received_calls[0]["explain_id"] == "urn:trustgraph:question:abc"
        assert received_calls[0]["explain_triples"] == q_triples
        assert received_calls[1]["explain_id"] == "urn:trustgraph:focus:abc"
        assert received_calls[1]["explain_triples"] == f_triples

    @pytest.mark.asyncio
    async def test_no_explain_callback_does_not_error(self):
        """
        When no explain_callback is provided, explain events should be
        silently skipped without errors.
        """
        responses = [
            self._explain_response(
                "urn:trustgraph:question:abc", sample_question_triples(),
            ),
            GraphRagResponse(
                message_type="chunk",
                response="Answer.",
                end_of_stream=True,
                end_of_session=True,
            ),
        ]
        client = self._client_replaying(responses)
        result = await client.rag(query="test")
        assert result == "Answer."

View file

@ -0,0 +1,482 @@
"""
Integration test: run a full GraphRag.query() with mocked subsidiary clients
and verify the explain_callback receives the complete provenance chain
in the correct order with correct structure.
This tests the real query() method end-to-end, not just the triple builders.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock
from dataclasses import dataclass
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
    """Return the first triple matching *predicate* (and optionally *subject*),
    or None when nothing matches."""
    matches = (
        t for t in triples
        if t.p.iri == predicate and (subject is None or t.s.iri == subject)
    )
    return next(matches, None)
def find_triples(triples, predicate, subject=None):
    """Return every triple matching *predicate* (and optionally *subject*)."""
    def matches(t):
        if t.p.iri != predicate:
            return False
        return subject is None or t.s.iri == subject
    return [t for t in triples if matches(t)]
def has_type(triples, subject, rdf_type):
    """True when *subject* carries an rdf:type triple with object *rdf_type*."""
    for t in triples:
        if t.s.iri != subject or t.p.iri != RDF_TYPE:
            continue
        if t.o.iri == rdf_type:
            return True
    return False
def derived_from(triples, subject):
    """Return the prov:wasDerivedFrom target of *subject*, or None if absent."""
    link = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    return None if link is None else link.o.iri
@dataclass
class EmbeddingMatch:
    """Mimics the result from graph_embeddings_client.query()."""
    entity: Term  # the matched graph entity, stored as an IRI Term in the mocks
# ---------------------------------------------------------------------------
# Mock setup
# ---------------------------------------------------------------------------
# A tiny knowledge graph: 2 entities, 3 edges
ENTITY_A = "http://example.com/QuantumComputing"
ENTITY_B = "http://example.com/Physics"
EDGE_1 = (ENTITY_A, "http://schema.org/relatedTo", ENTITY_B)
EDGE_2 = (ENTITY_A, "http://schema.org/name", "Quantum Computing")
EDGE_3 = (ENTITY_B, "http://schema.org/name", "Physics")
def make_schema_triple(s, p, o):
    """Create a SchemaTriple from string values.

    Objects beginning with "http" become IRI terms; anything else is a literal.
    """
    if o.startswith("http"):
        obj = Term(type=IRI, iri=o)
    else:
        obj = Term(type=LITERAL, value=o)
    return SchemaTriple(
        s=Term(type=IRI, iri=s),
        p=Term(type=IRI, iri=p),
        o=obj,
    )
def build_mock_clients():
    """
    Build mock clients that simulate a small knowledge graph query.

    Client call sequence during query():
    1. prompt_client.prompt("extract-concepts", ...) -> concepts
    2. embeddings_client.embed(concepts) -> vectors
    3. graph_embeddings_client.query(vector, ...) -> entity matches
    4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
    5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
    6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
    7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
    8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
    9. prompt_client.prompt("kg-synthesis", ...) -> final answer
    """
    prompt_client = AsyncMock()
    embeddings_client = AsyncMock()
    graph_embeddings_client = AsyncMock()
    triples_client = AsyncMock()

    # Embedding vectors (simple fake vectors)
    embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]

    # Entity lookup: both toy-graph entities match
    graph_embeddings_client.query.return_value = [
        EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_A)),
        EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
    ]

    # Edge traversal (follow_edges_batch): the whole toy graph
    triples_client.query_stream.return_value = [
        make_schema_triple(*edge) for edge in (EDGE_1, EDGE_2, EDGE_3)
    ]

    # Label resolution: nothing found, so the code falls back to the URI
    async def no_labels(s=None, p=None, o=None, limit=1,
                        user=None, collection=None, g=None):
        return []
    triples_client.query.side_effect = no_labels

    # Prompt dispatch. Scoring/reasoning respond dynamically to whatever
    # edges the query method sends, since edge IDs are computed from
    # str(Term) representations which include the full dataclass repr.
    async def fake_prompt(template_id, variables=None, **kwargs):
        if template_id == "extract-concepts":
            return "quantum computing\nphysics"
        if template_id == "kg-edge-scoring":
            # Score all edges highly, using the IDs that GraphRag computed
            return [
                {"id": e["id"], "score": 10 - i}
                for i, e in enumerate(variables.get("knowledge", []))
            ]
        if template_id == "kg-edge-reasoning":
            # Provide reasoning for each edge
            return [
                {"id": e["id"], "reasoning": f"Relevant edge {i}"}
                for i, e in enumerate(variables.get("knowledge", []))
            ]
        if template_id == "kg-synthesis":
            return "Quantum computing applies physics principles to computation."
        return ""
    prompt_client.prompt.side_effect = fake_prompt

    return prompt_client, embeddings_client, graph_embeddings_client, triples_client
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestGraphRagQueryProvenance:
    """
    Run a real GraphRag.query() and verify the provenance chain emitted
    via explain_callback.

    Expected event order:
    question, grounding, exploration, focus, synthesis.
    """

    # Query and synthesised answer wired into build_mock_clients().
    QUERY = "What is quantum computing?"
    ANSWER = "Quantum computing applies physics principles to computation."

    async def _run_query(self, **query_kwargs):
        """Run GraphRag.query() against freshly built mock clients.

        Collects each explain event as {"triples", "explain_id"} in
        emission order. edge_score_limit=0 skips the semantic pre-filter
        for simplicity; extra keyword arguments are forwarded to query().

        Returns:
            (events, answer) tuple.
        """
        rag = GraphRag(*build_mock_clients())
        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        answer = await rag.query(
            query=self.QUERY,
            explain_callback=explain_callback,
            edge_score_limit=0,
            **query_kwargs,
        )
        return events, answer

    @pytest.mark.asyncio
    async def test_explain_callback_receives_five_events(self):
        """query() should emit exactly 5 explain events."""
        events, _ = await self._run_query()
        assert len(events) == 5, (
            f"Expected 5 explain events (question, grounding, exploration, "
            f"focus, synthesis), got {len(events)}"
        )

    @pytest.mark.asyncio
    async def test_events_have_correct_types_in_order(self):
        """
        Events should arrive as:
        question, grounding, exploration, focus, synthesis.
        """
        events, _ = await self._run_query()
        expected_types = [
            TG_GRAPH_RAG_QUESTION,
            TG_GROUNDING,
            TG_EXPLORATION,
            TG_FOCUS,
            TG_SYNTHESIS,
        ]
        for i, expected_type in enumerate(expected_types):
            uri = events[i]["explain_id"]
            triples = events[i]["triples"]
            assert has_type(triples, uri, expected_type), (
                f"Event {i} (uri={uri}) should have type {expected_type}"
            )

    @pytest.mark.asyncio
    async def test_derivation_chain_links_correctly(self):
        """
        Each event's URI should link to the previous via wasDerivedFrom:
        question (none); grounding -> question; exploration -> grounding;
        focus -> exploration; synthesis -> focus.
        """
        events, _ = await self._run_query()
        uris = [e["explain_id"] for e in events]
        all_triples = [t for e in events for t in e["triples"]]
        # question has no parent
        assert derived_from(all_triples, uris[0]) is None
        # every later stage derives from its predecessor
        for child, parent in zip(uris[1:], uris):
            assert derived_from(all_triples, child) == parent

    @pytest.mark.asyncio
    async def test_question_event_carries_query_text(self):
        """The question event should contain the original query string."""
        events, _ = await self._run_query()
        q_uri = events[0]["explain_id"]
        t = find_triple(events[0]["triples"], TG_QUERY, q_uri)
        assert t is not None
        assert t.o.value == self.QUERY

    @pytest.mark.asyncio
    async def test_grounding_carries_concepts(self):
        """The grounding event should list extracted concepts."""
        events, _ = await self._run_query()
        gnd_uri = events[1]["explain_id"]
        concepts = find_triples(events[1]["triples"], TG_CONCEPT, gnd_uri)
        concept_values = {t.o.value for t in concepts}
        assert "quantum computing" in concept_values
        assert "physics" in concept_values

    @pytest.mark.asyncio
    async def test_exploration_has_edge_count(self):
        """The exploration event should report how many edges were found."""
        events, _ = await self._run_query()
        exp_uri = events[2]["explain_id"]
        t = find_triple(events[2]["triples"], TG_EDGE_COUNT, exp_uri)
        assert t is not None
        # Should be non-zero (we provided 3 edges, label edges filtered)
        assert int(t.o.value) > 0

    @pytest.mark.asyncio
    async def test_focus_has_selected_edges_with_reasoning(self):
        """
        The focus event should carry selected edges as quoted triples
        with reasoning text.
        """
        events, _ = await self._run_query()
        foc_uri = events[3]["explain_id"]
        foc_triples = events[3]["triples"]
        # Should have selected edges
        selected = find_triples(foc_triples, TG_SELECTED_EDGE, foc_uri)
        assert len(selected) > 0, "Focus should have at least one selected edge"
        # Each edge selection should have a quoted triple
        edge_t = find_triples(foc_triples, TG_EDGE)
        assert len(edge_t) > 0, "Focus should have tg:edge with quoted triples"
        for t in edge_t:
            assert t.o.triple is not None, "tg:edge object must be a quoted triple"
        # Should have non-empty reasoning
        reasoning = find_triples(foc_triples, TG_REASONING)
        assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
        reasoning_texts = {t.o.value for t in reasoning}
        assert any(r for r in reasoning_texts), "Reasoning should not be empty"

    @pytest.mark.asyncio
    async def test_synthesis_is_answer_type(self):
        """The synthesis event should have tg:Answer type."""
        events, _ = await self._run_query()
        syn_uri = events[4]["explain_id"]
        syn_triples = events[4]["triples"]
        assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
        assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)

    @pytest.mark.asyncio
    async def test_query_returns_answer_text(self):
        """query() should still return the synthesised answer."""
        _, answer = await self._run_query()
        assert answer == self.ANSWER

    @pytest.mark.asyncio
    async def test_parent_uri_links_question_to_parent(self):
        """When parent_uri is provided, question should derive from it."""
        parent = "urn:trustgraph:agent:iteration:xyz"
        events, _ = await self._run_query(parent_uri=parent)
        q_uri = events[0]["explain_id"]
        assert derived_from(events[0]["triples"], q_uri) == parent

    @pytest.mark.asyncio
    async def test_no_explain_callback_still_works(self):
        """query() without explain_callback should return answer normally."""
        rag = GraphRag(*build_mock_clients())
        result = await rag.query(query=self.QUERY, edge_score_limit=0)
        assert result == self.ANSWER

    @pytest.mark.asyncio
    async def test_all_triples_in_retrieval_graph(self):
        """All emitted triples should be in the urn:graph:retrieval graph."""
        events, _ = await self._run_query()
        for event in events:
            for t in event["triples"]:
                assert t.g == "urn:graph:retrieval", (
                    f"Triple {t.s.iri} {t.p.iri} should be in "
                    f"urn:graph:retrieval, got {t.g}"
                )

View file

@ -15,7 +15,7 @@ class GraphRagClient(RequestResponse):
user: User identifier
collection: Collection identifier
chunk_callback: Optional async callback(text, end_of_stream) for text chunks
explain_callback: Optional async callback(explain_id, explain_graph) for explain notifications
explain_callback: Optional async callback(explain_id, explain_graph, explain_triples) for explain notifications
timeout: Request timeout in seconds
Returns:
@ -30,7 +30,7 @@ class GraphRagClient(RequestResponse):
# Handle explain notifications
if resp.message_type == 'explain':
if explain_callback and resp.explain_id:
await explain_callback(resp.explain_id, resp.explain_graph)
await explain_callback(resp.explain_id, resp.explain_graph, resp.explain_triples)
return False # Continue receiving
# Handle text chunks

View file

@ -43,7 +43,7 @@ class DocumentRagClient(BaseClient):
user: User identifier
collection: Collection identifier
chunk_callback: Optional callback(text, end_of_stream) for text chunks
explain_callback: Optional callback(explain_id, explain_graph) for explain notifications
explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications
timeout: Request timeout in seconds
Returns:
@ -55,7 +55,7 @@ class DocumentRagClient(BaseClient):
# Handle explain notifications (response is None/empty, explain_id present)
if x.explain_id and not x.response:
if explain_callback:
explain_callback(x.explain_id, x.explain_graph)
explain_callback(x.explain_id, x.explain_graph, x.explain_triples)
return False # Continue receiving
# Handle text chunks

View file

@ -47,7 +47,7 @@ class GraphRagClient(BaseClient):
user: User identifier
collection: Collection identifier
chunk_callback: Optional callback(text, end_of_stream) for text chunks
explain_callback: Optional callback(explain_id, explain_graph) for explain notifications
explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications
timeout: Request timeout in seconds
Returns:
@ -59,7 +59,7 @@ class GraphRagClient(BaseClient):
# Handle explain notifications
if x.message_type == 'explain':
if explain_callback and x.explain_id:
explain_callback(x.explain_id, x.explain_graph)
explain_callback(x.explain_id, x.explain_graph, x.explain_triples)
return False # Continue receiving
# Handle text chunks

View file

@ -39,13 +39,14 @@ class KnowledgeQueryImpl:
if respond:
from ... schema import AgentResponse
async def explain_callback(explain_id, explain_graph):
async def explain_callback(explain_id, explain_graph, explain_triples=None):
self.context.last_sub_explain_uri = explain_id
await respond(AgentResponse(
chunk_type="explain",
content="",
explain_id=explain_id,
explain_graph=explain_graph,
explain_triples=explain_triples or [],
))
if current_uri: