Compare commits


3 commits

Author SHA1 Message Date
cybermaggedon
aff96e57cb
Added Explainable AI agent demo in TypeScript (#770)
(Not functional code)
2026-04-08 14:16:14 +01:00
cybermaggedon
e81418c58f
fix: preserve literal types in focus quoted triples and document tracing (#769)
The triples client returns Uri/Literal (str subclasses), not Term
objects.  _quoted_triple() treated all values as IRIs, so literal
objects like skos:definition values were mistyped in focus
provenance events, and trace_source_documents could not match
them in the store.

Added to_term() to convert Uri/Literal back to Term, threaded a
term_map from follow_edges_batch through
get_subgraph/get_labelgraph into uri_map, and updated
_quoted_triple to accept Term objects directly.
2026-04-08 13:37:02 +01:00
cybermaggedon
4b5bfacab1
Forward missing explain_triples through RAG clients and agent tool callback (#768)
fix: forward explain_triples through RAG clients and agent tool callback
- RAG clients and the KnowledgeQueryImpl tool callback were
  dropping explain_triples from explain events, losing provenance
  data (including focus edge selections) when graph-rag is invoked
  via the agent.

Tests for provenance and explainability (56 new):
- Client-level forwarding of explain_triples
- Graph-RAG structural chain
  (question → grounding → exploration → focus → synthesis)
- Graph-RAG integration with mocked subsidiary clients
- Document-RAG integration
  (question → grounding → exploration → synthesis)
- Agent-orchestrator all 3 patterns: react, plan-then-execute,
  supervisor
2026-04-08 11:41:17 +01:00
15 changed files with 2840 additions and 32 deletions


@@ -0,0 +1,29 @@
# Explainable AI Demo

Demonstrates the TrustGraph streaming agent API with inline explainability
events. Sends an agent query, receives streaming thinking/observation/answer
chunks alongside RDF provenance events, then resolves the full provenance
chain from answer back to source documents.

## What it shows

- Streaming agent responses (thinking, observation, answer)
- Inline explainability events with RDF triples (W3C PROV + TrustGraph namespace)
- Label resolution for entity and predicate URIs
- Provenance chain traversal: subgraph → chunk → page → document
- Source text retrieval from the librarian using chunk IDs

## Prerequisites

A running TrustGraph instance with at least one loaded document and a
running flow. The default configuration connects to `ws://localhost:8088`.

## Usage

```bash
npm install
node index.js
```

Edit the `QUESTION` and `SOCKET_URL` constants at the top of `index.js`
to change the query or target instance.
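
The provenance traversal can be sketched as a pure function over an in-memory
set of `wasDerivedFrom` triples. This is a minimal sketch: the term shape
`{ t: "i", i: uri }` mirrors what `flow.triplesQuery` returns in `index.js`,
but the URIs are made up, no live instance is queried, and the real demo first
locates the subgraph via a `tg:contains` RDF-star query before walking the chain.

```javascript
// Sketch of the subgraph -> chunk -> page -> document walk performed by
// the demo, run against hard-coded example triples instead of a live
// TrustGraph instance. IRI terms use the same shape as index.js.
const PROV_DERIVED = "http://www.w3.org/ns/prov#wasDerivedFrom";
const iri = (uri) => ({ t: "i", i: uri });

// Follow one wasDerivedFrom hop from a subject URI.
const derivedFrom = (triples, subjectUri) => {
    const hit = triples.find(
        t => t.s.i === subjectUri && t.p.i === PROV_DERIVED
    );
    return hit ? hit.o.i : undefined;
};

// Walk the whole chain starting from a subgraph URI; missing hops
// simply leave the later fields undefined, as in resolveEdgeSources.
const resolveChain = (triples, subgraph) => {
    const chunk = derivedFrom(triples, subgraph);
    const page = chunk && derivedFrom(triples, chunk);
    const document = page && derivedFrom(triples, page);
    return { subgraph, chunk, page, document };
};

// Example provenance triples (hypothetical URIs)
const triples = [
    { s: iri("urn:subgraph:1"), p: iri(PROV_DERIVED), o: iri("urn:chunk:a") },
    { s: iri("urn:chunk:a"), p: iri(PROV_DERIVED), o: iri("urn:page:3") },
    { s: iri("urn:page:3"), p: iri(PROV_DERIVED), o: iri("urn:doc:x") },
];

// Logs the resolved chain ending at urn:doc:x
console.log(resolveChain(triples, "urn:subgraph:1"));
```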


@@ -0,0 +1,552 @@
// ============================================================================
// TrustGraph Explainability API Demo
// ============================================================================
//
// This example demonstrates how to use the TrustGraph streaming agent API
// with explainability events. It shows how to:
//
// 1. Send an agent query and receive streaming thinking/observation/answer
// 2. Receive and parse explainability events as they arrive
// 3. Resolve the provenance chain for knowledge graph edges:
//    subgraph -> chunk -> page -> document
// 4. Fetch source text from the librarian using chunk IDs
//
// Explainability events use RDF triples (W3C PROV ontology + TrustGraph
// namespace) to describe the retrieval pipeline. The key event types are:
//
// - AgentQuestion: The initial user query
// - Analysis/ToolUse: Agent deciding which tool to invoke
// - GraphRagQuestion: A sub-query sent to the Graph RAG pipeline
// - Grounding: Concepts extracted from the query for graph traversal
// - Exploration: Entities discovered during knowledge graph traversal
// - Focus: The selected knowledge graph edges (triples) used for context
// - Synthesis: The RAG answer synthesised from retrieved context
// - Observation: The tool result returned to the agent
// - Conclusion/Answer: The agent's final answer
//
// Each event carries RDF triples that link back through the provenance chain,
// allowing full traceability from answer back to source documents.
// ============================================================================

import { createTrustGraphSocket } from '@trustgraph/client';

// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

const USER = "trustgraph";

// Simple question
const QUESTION = "Tell me about the author of the document";

// Likely to trigger the deep research plan-and-execute pattern
//const QUESTION = "Do deep research and explain the risks posed by globalisation in the modern world";

const SOCKET_URL = "ws://localhost:8088/api/v1/socket";
// ---------------------------------------------------------------------------
// RDF predicates and TrustGraph namespace constants
// ---------------------------------------------------------------------------

const RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type";
const RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label";
const PROV_DERIVED = "http://www.w3.org/ns/prov#wasDerivedFrom";
const TG_GROUNDING = "https://trustgraph.ai/ns/Grounding";
const TG_CONCEPT = "https://trustgraph.ai/ns/concept";
const TG_EXPLORATION = "https://trustgraph.ai/ns/Exploration";
const TG_ENTITY = "https://trustgraph.ai/ns/entity";
const TG_FOCUS = "https://trustgraph.ai/ns/Focus";
const TG_EDGE = "https://trustgraph.ai/ns/edge";
const TG_CONTAINS = "https://trustgraph.ai/ns/contains";

// ---------------------------------------------------------------------------
// Utility: check whether a set of triples assigns a given RDF type to an ID
// ---------------------------------------------------------------------------

const isType = (triples, id, type) =>
    triples.some(t => t.s.i === id && t.p.i === RDF_TYPE && t.o.i === type);

// ---------------------------------------------------------------------------
// Utility: word-wrap text for display
// ---------------------------------------------------------------------------

const wrapText = (text, width, indent, maxLines) => {
    const clean = text.replace(/\s+/g, " ").trim();
    const lines = [];
    let remaining = clean;
    while (remaining.length > 0 && lines.length < maxLines) {
        if (remaining.length <= width) {
            lines.push(remaining);
            break;
        }
        let breakAt = remaining.lastIndexOf(" ", width);
        if (breakAt <= 0) breakAt = width;
        lines.push(remaining.substring(0, breakAt));
        remaining = remaining.substring(breakAt).trimStart();
    }
    if (remaining.length > 0 && lines.length >= maxLines)
        lines[lines.length - 1] += " ...";
    return lines.map(l => indent + l).join("\n");
};
// ---------------------------------------------------------------------------
// Connect to TrustGraph
// ---------------------------------------------------------------------------

console.log("=".repeat(80));
console.log("TrustGraph Explainability API Demo");
console.log("=".repeat(80));
console.log(`Connecting to: ${SOCKET_URL}`);
console.log(`Question: ${QUESTION}`);
console.log("=".repeat(80));

const client = createTrustGraphSocket(USER, undefined, SOCKET_URL);
console.log("Connected, sending query...\n");

// Get a flow handle. Flows provide access to AI operations (agent, RAG,
// text completion, etc.) as well as knowledge graph queries.
const flow = client.flow("default");

// Get a librarian handle for fetching source document text.
const librarian = client.librarian();

// ---------------------------------------------------------------------------
// Inline explain event printing
// ---------------------------------------------------------------------------
// Explain events arrive during streaming alongside thinking/observation/
// answer chunks. We print a summary immediately and store them for
// post-processing (label resolution and provenance lookups require async
// queries that can't run inside the synchronous callback).

const explainEvents = [];

const printExplainInline = (explainEvent) => {
    const { explainId, explainTriples } = explainEvent;
    if (!explainTriples) return;

    // Extract the RDF types assigned to the explain event's own ID.
    // Every explain event has rdf:type triples that identify what kind
    // of pipeline step it represents (Grounding, Exploration, Focus, etc.)
    const types = explainTriples
        .filter(t => t.s.i === explainId && t.p.i === RDF_TYPE)
        .map(t => t.o.i);

    // Show short type names (e.g. "Grounding" instead of full URI)
    const shortTypes = types
        .map(t => t.split("/").pop().split("#").pop())
        .join(", ");
    console.log(` [explain] ${shortTypes}`);

    // Grounding events contain the concepts extracted from the query.
    // These are the seed terms used to begin knowledge graph traversal.
    if (isType(explainTriples, explainId, TG_GROUNDING)) {
        const concepts = explainTriples
            .filter(t => t.s.i === explainId && t.p.i === TG_CONCEPT)
            .map(t => t.o.v);
        console.log(` Grounding concepts: ${concepts.join(", ")}`);
    }

    // Exploration events list the entities found during graph traversal.
    // We show the count here; labelled names are printed after resolution.
    if (isType(explainTriples, explainId, TG_EXPLORATION)) {
        const count = explainTriples
            .filter(t => t.s.i === explainId && t.p.i === TG_ENTITY).length;
        console.log(` Entities: ${count} found (see below)`);
    }
};

const collectExplain = (explainEvent) => {
    printExplainInline(explainEvent);
    explainEvents.push(explainEvent);
};
// ---------------------------------------------------------------------------
// Label resolution
// ---------------------------------------------------------------------------
// Many explain triples reference entities and predicates by URI. We query
// the knowledge graph for rdfs:label to get human-readable names.

const resolveLabels = async (uris) => {
    const labels = new Map();
    await Promise.all(uris.map(async (uri) => {
        try {
            const results = await flow.triplesQuery(
                { t: "i", i: uri },
                { t: "i", i: RDFS_LABEL },
            );
            if (results.length > 0) {
                labels.set(uri, results[0].o.v);
            }
        } catch (e) {
            // No label found, fall back to URI
        }
    }));
    return labels;
};

// ---------------------------------------------------------------------------
// Provenance resolution for knowledge graph edges
// ---------------------------------------------------------------------------
// Focus events contain the knowledge graph triples (edges) that were selected
// as context for the RAG answer. Each edge can be traced back through the
// provenance chain to the original source document:
//
// subgraph --contains--> <<edge triple>> (RDF-star triple term)
// subgraph --wasDerivedFrom--> chunk (text chunk)
// chunk --wasDerivedFrom--> page (document page)
// page --wasDerivedFrom--> document (original document)
//
// The chunk URI also serves as the content ID in the librarian, so it can
// be used to fetch the actual source text.

const resolveEdgeSources = async (edgeTriples) => {
    const iri = (uri) => ({ t: "i", i: uri });
    const sources = new Map();
    await Promise.all(edgeTriples.map(async (tr) => {
        const key = JSON.stringify(tr);
        try {
            // Step 1: Find the subgraph that contains this edge triple.
            // The query uses an RDF-star triple term as the object: the
            // knowledge graph stores subgraph -> contains -> <<s, p, o>>.
            const subgraphResults = await flow.triplesQuery(
                undefined,
                iri(TG_CONTAINS),
                { t: "t", tr },
            );
            if (subgraphResults.length === 0) {
                if (tr.o.t === "l" || tr.o.t === "i") {
                    console.log(` No source match for triple:`);
                    console.log(` s: ${tr.s.i}`);
                    console.log(` p: ${tr.p.i}`);
                    console.log(` o: ${JSON.stringify(tr.o)}`);
                }
                return;
            }
            const subgraph = subgraphResults[0].s.i;

            // Step 2: Walk wasDerivedFrom chain: subgraph -> chunk
            const chunkResults = await flow.triplesQuery(
                iri(subgraph), iri(PROV_DERIVED),
            );
            if (chunkResults.length === 0) {
                sources.set(key, { subgraph });
                return;
            }
            const chunk = chunkResults[0].o.i;

            // Step 3: chunk -> page
            const pageResults = await flow.triplesQuery(
                iri(chunk), iri(PROV_DERIVED),
            );
            if (pageResults.length === 0) {
                sources.set(key, { subgraph, chunk });
                return;
            }
            const page = pageResults[0].o.i;

            // Step 4: page -> document
            const docResults = await flow.triplesQuery(
                iri(page), iri(PROV_DERIVED),
            );
            const document = docResults.length > 0 ? docResults[0].o.i : undefined;
            sources.set(key, { subgraph, chunk, page, document });
        } catch (e) {
            // Query failed, skip this edge
        }
    }));
    return sources;
};
// ---------------------------------------------------------------------------
// Collect URIs that need label resolution
// ---------------------------------------------------------------------------
// Scans explain events for entity URIs (from Exploration events) and edge
// term URIs (from Focus events) so we can batch-resolve their labels.

const collectUris = (events) => {
    const uris = new Set();
    for (const { explainId, explainTriples } of events) {
        if (!explainTriples) continue;

        // Entity URIs from exploration
        if (isType(explainTriples, explainId, TG_EXPLORATION)) {
            for (const t of explainTriples) {
                if (t.s.i === explainId && t.p.i === TG_ENTITY)
                    uris.add(t.o.i);
            }
        }

        // Subject, predicate, and object URIs from focus edge triples
        if (isType(explainTriples, explainId, TG_FOCUS)) {
            for (const t of explainTriples) {
                if (t.p.i === TG_EDGE && t.o.t === "t") {
                    const tr = t.o.tr;
                    if (tr.s.t === "i") uris.add(tr.s.i);
                    if (tr.p.t === "i") uris.add(tr.p.i);
                    if (tr.o.t === "i") uris.add(tr.o.i);
                }
            }
        }
    }
    return uris;
};

// ---------------------------------------------------------------------------
// Collect edge triples from Focus events
// ---------------------------------------------------------------------------
// Focus events contain selectedEdge -> edge relationships. Each edge's
// object is an RDF-star triple term ({t: "t", tr: {s, p, o}}) representing
// the actual knowledge graph triple used as RAG context.

const collectEdgeTriples = (events) => {
    const edges = [];
    for (const { explainId, explainTriples } of events) {
        if (!explainTriples) continue;
        if (isType(explainTriples, explainId, TG_FOCUS)) {
            for (const t of explainTriples) {
                if (t.p.i === TG_EDGE && t.o.t === "t")
                    edges.push(t.o.tr);
            }
        }
    }
    return edges;
};

// ---------------------------------------------------------------------------
// Print knowledge graph edges with provenance
// ---------------------------------------------------------------------------
// Displays each edge triple with resolved labels and its source location
// (chunk -> page -> document).

const printFocusEdges = (events, labels, edgeSources) => {
    const label = (uri) => labels.get(uri) || uri;
    for (const { explainId, explainTriples } of events) {
        if (!explainTriples) continue;
        if (!isType(explainTriples, explainId, TG_FOCUS)) continue;

        const termValue = (term) =>
            term.t === "i" ? label(term.i) : (term.v || "?");

        const edges = explainTriples
            .filter(t => t.p.i === TG_EDGE && t.o.t === "t")
            .map(t => t.o.tr);

        const display = edges.slice(0, 20);
        for (const tr of display) {
            console.log(` ${termValue(tr.s)} -> ${termValue(tr.p)} -> ${termValue(tr.o)}`);
            const src = edgeSources.get(JSON.stringify(tr));
            if (src) {
                const parts = [];
                if (src.chunk) parts.push(label(src.chunk));
                if (src.page) parts.push(label(src.page));
                if (src.document) parts.push(label(src.document));
                if (parts.length > 0)
                    console.log(` Source: ${parts.join(" -> ")}`);
            }
        }
        if (edges.length > 20)
            console.log(` ... and ${edges.length - 20} more`);
    }
};

// ---------------------------------------------------------------------------
// Fetch chunk text from the librarian
// ---------------------------------------------------------------------------
// The chunk URI (e.g. urn:chunk:UUID) serves as a universal ID that ties
// together provenance metadata, embeddings, and the source text content.
// The librarian stores the original text keyed by this same URI, so we
// can retrieve it with streamDocument(chunkUri).

const fetchChunkText = (chunkUri) => {
    return new Promise((resolve, reject) => {
        let text = "";
        librarian.streamDocument(
            chunkUri,
            (content, chunkIndex, totalChunks, complete) => {
                text += content;
                if (complete) resolve(text);
            },
            (error) => reject(error),
        );
    });
};
// ===========================================================================
// Send the agent query
// ===========================================================================
// The agent callback receives four types of streaming content:
// - think: the agent's reasoning (chain-of-thought)
// - observe: tool results returned to the agent
// - answer: the final answer being generated
// - error: any errors during processing
//
// The onExplain callback fires for each explainability event, delivering
// RDF triples that describe what happened at each pipeline stage.

let thought = "";
let obs = "";
let ans = "";

await flow.agent(
    QUESTION,

    // Think callback: agent reasoning / chain-of-thought
    (chunk, complete, messageId, metadata) => {
        thought += chunk;
        if (complete) {
            console.log("\nThinking:", thought, "\n");
            thought = "";
        }
    },

    // Observe callback: tool results returned to the agent
    (chunk, complete, messageId, metadata) => {
        obs += chunk;
        if (complete) {
            console.log("\nObservation:", obs, "\n");
            obs = "";
        }
    },

    // Answer callback: the agent's final response
    (chunk, complete, messageId, metadata) => {
        ans += chunk;
        if (complete) {
            console.log("\nAnswer:", ans, "\n");
            ans = "";
        }
    },

    // Error callback
    (error) => {
        console.log(JSON.stringify({ type: "error", error }, null, 2));
    },

    // Explain callback: explainability events with RDF triples
    (explainEvent) => {
        collectExplain(explainEvent);
    }
);
// ===========================================================================
// Post-processing: resolve labels, provenance, and source text
// ===========================================================================
// After the agent query completes, we have all the explain events. Now we
// can make async queries to:
// 1. Trace each edge back to its source document (provenance chain)
// 2. Resolve URIs to human-readable labels
// 3. Fetch the original text for each source chunk

console.log("Resolving provenance...\n");

// Resolve the provenance chain for each knowledge graph edge
const edgeTriples = collectEdgeTriples(explainEvents);
const edgeSources = await resolveEdgeSources(edgeTriples);

// Collect all URIs that need labels: entities, edge terms, and source URIs
const uris = collectUris(explainEvents);
for (const src of edgeSources.values()) {
    if (src.chunk) uris.add(src.chunk);
    if (src.page) uris.add(src.page);
    if (src.document) uris.add(src.document);
}
const labels = await resolveLabels([...uris]);
const label = (uri) => labels.get(uri) || uri;

// ---------------------------------------------------------------------------
// Display: Entities retrieved during graph exploration
// ---------------------------------------------------------------------------

for (const { explainId, explainTriples } of explainEvents) {
    if (!explainTriples) continue;
    if (!isType(explainTriples, explainId, TG_EXPLORATION)) continue;
    const entities = explainTriples
        .filter(t => t.s.i === explainId && t.p.i === TG_ENTITY)
        .map(t => label(t.o.i));
    const display = entities.slice(0, 10);
    console.log("=".repeat(80));
    console.log("Entities Retrieved");
    console.log("=".repeat(80));
    console.log(` ${entities.length} entities: ${display.join(", ")}${entities.length > 10 ? ", ..." : ""}`);
}

// ---------------------------------------------------------------------------
// Display: Knowledge graph edges with provenance
// ---------------------------------------------------------------------------

console.log("\n" + "=".repeat(80));
console.log("Knowledge Graph Edges");
console.log("=".repeat(80));
printFocusEdges(explainEvents, labels, edgeSources);

// ---------------------------------------------------------------------------
// Display: Source text for each chunk referenced by the edges
// ---------------------------------------------------------------------------

const uniqueChunks = new Set();
for (const src of edgeSources.values()) {
    if (src.chunk) uniqueChunks.add(src.chunk);
}
console.log(`\nFetching text for ${uniqueChunks.size} source chunks...`);

const chunkTexts = new Map();
await Promise.all([...uniqueChunks].map(async (chunkUri) => {
    try {
        // streamDocument returns base64-encoded content
        const text = await fetchChunkText(chunkUri);
        chunkTexts.set(chunkUri, text);
    } catch (e) {
        // Failed to fetch text for this chunk
    }
}));

console.log("\n" + "=".repeat(80));
console.log("Sources");
console.log("=".repeat(80));

let sourceIndex = 0;
for (const chunkUri of uniqueChunks) {
    sourceIndex++;
    const chunkLabel = labels.get(chunkUri) || chunkUri;

    // Find the page and document labels for this chunk
    let pageLabel, docLabel;
    for (const src of edgeSources.values()) {
        if (src.chunk === chunkUri) {
            if (src.page) pageLabel = labels.get(src.page) || src.page;
            if (src.document) docLabel = labels.get(src.document) || src.document;
            break;
        }
    }

    console.log(`\n [${sourceIndex}] ${docLabel || "?"} / ${pageLabel || "?"} / ${chunkLabel}`);
    console.log(" " + "-".repeat(70));

    // Decode the base64 content and display a wrapped snippet
    const b64 = chunkTexts.get(chunkUri);
    if (b64) {
        const text = Buffer.from(b64, "base64").toString("utf-8");
        console.log(wrapText(text, 76, " ", 6));
    }
}

// ---------------------------------------------------------------------------
// Clean up
// ---------------------------------------------------------------------------

console.log("\n" + "=".repeat(80));
console.log("Query complete");
console.log("=".repeat(80));
client.close();
process.exit(0);


@@ -0,0 +1,13 @@
{
  "name": "explain-api-example",
  "version": "1.0.0",
  "description": "TrustGraph explainability API example",
  "main": "index.js",
  "type": "module",
  "scripts": {
    "test": "echo \"Error: no test specified\" && exit 1"
  },
  "dependencies": {
    "@trustgraph/client": "^1.7.2"
  }
}


@@ -0,0 +1,655 @@
"""
Integration tests for agent-orchestrator provenance chains.
Tests all three patterns by calling iterate() with mocked dependencies
and verifying the explain events emitted via respond().
Provenance chains:
React: session iteration (observation or final)
Plan: session plan step-result(s) synthesis
Supervisor: session decomposition finding(s) synthesis
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import dataclass, field
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
GRAPH_RETRIEVAL,
)
# Agent provenance type constants
from trustgraph.provenance.namespaces import (
TG_AGENT_QUESTION,
TG_ANALYSIS,
TG_TOOL_USE,
TG_OBSERVATION_TYPE,
TG_CONCLUSION,
TG_DECOMPOSITION,
TG_FINDING,
TG_PLAN_TYPE,
TG_STEP_RESULT,
TG_SYNTHESIS as TG_AGENT_SYNTHESIS,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def find_triple(triples, predicate, subject=None):
    for t in triples:
        if t.p.iri == predicate:
            if subject is None or t.s.iri == subject:
                return t
    return None


def has_type(triples, subject, rdf_type):
    return any(
        t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
        for t in triples
    )


def derived_from(triples, subject):
    t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    return t.o.iri if t else None


def collect_explain_events(respond_mock):
    """Extract explain events from a respond mock's call history."""
    events = []
    for call in respond_mock.call_args_list:
        resp = call[0][0]
        if isinstance(resp, AgentResponse) and resp.chunk_type == "explain":
            events.append({
                "explain_id": resp.explain_id,
                "explain_graph": resp.explain_graph,
                "triples": resp.explain_triples,
            })
    return events


# ---------------------------------------------------------------------------
# Mock processor
# ---------------------------------------------------------------------------

def make_mock_processor(tools=None):
    """Build a mock processor with the minimal interface patterns need."""
    processor = MagicMock()
    processor.max_iterations = 10
    processor.save_answer_content = AsyncMock()

    # provenance_session_uri must return a real URI
    def mock_session_uri(session_id):
        return f"urn:trustgraph:agent:session:{session_id}"
    processor.provenance_session_uri.side_effect = mock_session_uri

    # Agent with tools
    agent = MagicMock()
    agent.tools = tools or {}
    agent.additional_context = ""
    processor.agent = agent

    # Aggregator for supervisor
    processor.aggregator = MagicMock()

    return processor


def make_mock_flow():
    """Build a mock flow that returns async mock producers."""
    producers = {}

    def flow_factory(name):
        if name not in producers:
            producers[name] = AsyncMock()
        return producers[name]

    flow = MagicMock(side_effect=flow_factory)
    flow._producers = producers
    return flow


def make_base_request(**kwargs):
    """Build a minimal AgentRequest."""
    defaults = dict(
        question="What is quantum computing?",
        state="",
        group=[],
        history=[],
        user="testuser",
        collection="default",
        streaming=False,
        session_id="test-session-123",
        conversation_id="",
        pattern="react",
        task_type="",
        framing="",
        correlation_id="",
        parent_session_id="",
        subagent_goal="",
        expected_siblings=0,
    )
    defaults.update(kwargs)
    return AgentRequest(**defaults)
# ---------------------------------------------------------------------------
# React pattern tests
# ---------------------------------------------------------------------------

class TestReactPatternProvenance:
    """
    React pattern chain: session → iteration → final
    (single iteration ending in Final answer)
    """

    @pytest.mark.asyncio
    async def test_single_iteration_final_answer(self):
        """
        A single react iteration that produces a Final answer should emit:
        session, iteration, final in that order.
        """
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action, Final

        processor = make_mock_processor()
        pattern = ReactPattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()
        request = make_base_request()

        # Mock AgentManager.react to call on_action then return Final
        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am
            final = Final(
                thought="I know the answer",
                final="Quantum computing uses qubits.",
            )

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                # Simulate the on_action callback before returning Final
                if on_action:
                    await on_action(Action(
                        thought="I know the answer",
                        name="final",
                        arguments={},
                        observation="",
                    ))
                return final

            mock_am.react.side_effect = mock_react

            await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 3 events: session, iteration, final
        assert len(events) == 3, (
            f"Expected 3 explain events (session, iteration, final), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        # Check types
        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
        assert has_type(events[2]["triples"], events[2]["explain_id"], TG_CONCLUSION)

        # Check derivation chain
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])
        uris = [e["explain_id"] for e in events]

        # iteration derives from session
        assert derived_from(all_triples, uris[1]) == uris[0]
        # final derives from session (first iteration, no prior observation)
        assert derived_from(all_triples, uris[2]) == uris[0]

    @pytest.mark.asyncio
    async def test_iteration_with_tool_call(self):
        """
        A react iteration that calls a tool (not Final) should emit:
        session, iteration, observation then call next() for continuation.
        """
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action

        # Create a mock tool
        mock_tool = MagicMock()
        mock_tool.name = "knowledge-query"
        mock_tool.description = "Query the knowledge base"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="The answer is 42")
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = make_mock_processor(
            tools={"knowledge-query": mock_tool}
        )
        pattern = ReactPattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()
        request = make_base_request()

        action = Action(
            thought="I need to look this up",
            name="knowledge-query",
            arguments={"question": "What is quantum computing?"},
            observation="Quantum computing uses qubits.",
        )

        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                if on_action:
                    await on_action(action)
                return action

            mock_am.react.side_effect = mock_react

            await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 3 events: session, iteration, observation
        assert len(events) == 3, (
            f"Expected 3 explain events (session, iteration, observation), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )
        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
        assert has_type(events[2]["triples"], events[2]["explain_id"], TG_OBSERVATION_TYPE)

        # next() should have been called to continue the loop
        assert next_fn.called

    @pytest.mark.asyncio
    async def test_all_triples_in_retrieval_graph(self):
        """All explain triples should be in urn:graph:retrieval."""
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action, Final

        processor = make_mock_processor()
        pattern = ReactPattern(processor)
        respond = AsyncMock()
        flow = make_mock_flow()

        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                if on_action:
                    await on_action(Action(
                        thought="done", name="final",
                        arguments={}, observation="",
                    ))
                return Final(thought="done", final="answer")

            mock_am.react.side_effect = mock_react

            await pattern.iterate(
                make_base_request(), respond, AsyncMock(), flow,
            )

        for event in collect_explain_events(respond):
            for t in event["triples"]:
                assert t.g == GRAPH_RETRIEVAL
# ---------------------------------------------------------------------------
# Plan-then-execute pattern tests
# ---------------------------------------------------------------------------

class TestPlanPatternProvenance:
    """
    Plan pattern chain:
    Planning iteration: session → plan
    Execution iterations: step-result(s) → synthesis
    """

    @pytest.mark.asyncio
    async def test_planning_iteration_emits_session_and_plan(self):
        """
        The first iteration (planning) should emit:
        session, plan then call next() with the plan in history.
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        processor = make_mock_processor()
        pattern = PlanThenExecutePattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt client for plan creation
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = [
            {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
            {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
        ]

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()
        flow.side_effect = flow_factory

        request = make_base_request(pattern="plan")

        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 2 events: session, plan
        assert len(events) == 2, (
            f"Expected 2 explain events (session, plan), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )
        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_PLAN_TYPE)

        # Plan should derive from session
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])
        assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"]

        # next() should have been called with plan in history
        assert next_fn.called

    @pytest.mark.asyncio
    async def test_execution_iteration_emits_step_result(self):
        """
        An execution iteration should emit a step-result event.
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        # Create a mock tool
        mock_tool = MagicMock()
        mock_tool.name = "knowledge-query"
        mock_tool.description = "Query KB"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="Found the answer")
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = make_mock_processor(
            tools={"knowledge-query": mock_tool}
        )
        pattern = PlanThenExecutePattern(processor)
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt for step execution
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = {
            "tool": "knowledge-query",
            "arguments": {"question": "quantum computing"},
        }

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()
        flow.side_effect = flow_factory

        # Request with plan already in history (second iteration)
        plan_step = AgentStep(
            thought="Created plan",
            action="plan",
            arguments={},
            observation="[]",
            step_type="plan",
            plan=[
                PlanStep(goal="Find info", tool_hint="knowledge-query",
                         depends_on=[], status="pending", result=""),
            ],
        )
        request = make_base_request(
            pattern="plan",
            history=[plan_step],
        )

        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have step-result (no session on iteration > 1)
        step_events = [
            e for e in events
            if has_type(e["triples"], e["explain_id"], TG_STEP_RESULT)
        ]
        assert len(step_events) == 1, (
f"Expected 1 step-result event, got {len(step_events)}"
)
@pytest.mark.asyncio
async def test_synthesis_after_all_steps_complete(self):
"""
When all plan steps are completed, synthesis should be emitted.
"""
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
processor = make_mock_processor()
pattern = PlanThenExecutePattern(processor)
respond = AsyncMock()
next_fn = AsyncMock()
flow = make_mock_flow()
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The synthesised answer."
def flow_factory(name):
if name == "prompt-request":
return mock_prompt_client
return AsyncMock()
flow.side_effect = flow_factory
# Request with all steps completed
exec_step = AgentStep(
thought="Executing step",
action="knowledge-query",
arguments={},
observation="Result",
step_type="execute",
plan=[
PlanStep(goal="Find info", tool_hint="knowledge-query",
depends_on=[], status="completed", result="Found it"),
],
)
request = make_base_request(
pattern="plan",
history=[exec_step],
)
await pattern.iterate(request, respond, next_fn, flow)
events = collect_explain_events(respond)
# Should have synthesis event
synth_events = [
e for e in events
if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS)
]
assert len(synth_events) == 1, (
f"Expected 1 synthesis event, got {len(synth_events)}"
)
# ---------------------------------------------------------------------------
# Supervisor pattern tests
# ---------------------------------------------------------------------------
class TestSupervisorPatternProvenance:
"""
Supervisor pattern chain:
Decompose: session → decomposition
(Fan-out to subagents happens externally)
Synthesise: synthesis (derives from findings)
"""
@pytest.mark.asyncio
async def test_decompose_emits_session_and_decomposition(self):
"""
The decompose phase should emit: session, decomposition.
"""
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
processor = make_mock_processor()
pattern = SupervisorPattern(processor)
respond = AsyncMock()
next_fn = AsyncMock()
flow = make_mock_flow()
# Mock prompt for decomposition
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
"What is quantum computing?",
"What are qubits?",
]
def flow_factory(name):
if name == "prompt-request":
return mock_prompt_client
return AsyncMock()
flow.side_effect = flow_factory
request = make_base_request(pattern="supervisor")
await pattern.iterate(request, respond, next_fn, flow)
events = collect_explain_events(respond)
# Should have 2 events: session, decomposition
assert len(events) == 2, (
f"Expected 2 explain events (session, decomposition), "
f"got {len(events)}: {[e['explain_id'] for e in events]}"
)
assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
assert has_type(events[1]["triples"], events[1]["explain_id"], TG_DECOMPOSITION)
# Decomposition derives from session
all_triples = []
for e in events:
all_triples.extend(e["triples"])
assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"]
@pytest.mark.asyncio
async def test_synthesis_emits_after_subagent_results(self):
"""
When subagent results arrive, synthesis should be emitted.
"""
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
processor = make_mock_processor()
pattern = SupervisorPattern(processor)
respond = AsyncMock()
next_fn = AsyncMock()
flow = make_mock_flow()
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The combined answer."
def flow_factory(name):
if name == "prompt-request":
return mock_prompt_client
return AsyncMock()
flow.side_effect = flow_factory
# Request with subagent results in history
synth_step = AgentStep(
thought="",
action="synthesise",
arguments={},
observation="",
step_type="synthesise",
subagent_results={
"What is quantum computing?": "It uses qubits",
"What are qubits?": "Quantum bits",
},
)
request = make_base_request(
pattern="supervisor",
history=[synth_step],
)
await pattern.iterate(request, respond, next_fn, flow)
events = collect_explain_events(respond)
# Should have synthesis event (no session on iteration > 1)
synth_events = [
e for e in events
if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS)
]
assert len(synth_events) == 1
@pytest.mark.asyncio
async def test_decompose_fans_out_subagents(self):
"""The decompose phase should call next() for each subagent goal."""
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
processor = make_mock_processor()
pattern = SupervisorPattern(processor)
respond = AsyncMock()
next_fn = AsyncMock()
flow = make_mock_flow()
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
def flow_factory(name):
if name == "prompt-request":
return mock_prompt_client
return AsyncMock()
flow.side_effect = flow_factory
request = make_base_request(pattern="supervisor")
await pattern.iterate(request, respond, next_fn, flow)
# 3 subagent requests fanned out
assert next_fn.call_count == 3
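The fan-out contract exercised above can be sketched in isolation. The dict payload and `next_fn` shape here are illustrative assumptions, not the real orchestrator API:

```python
import asyncio
from unittest.mock import AsyncMock

# Toy version of the decompose phase: one next() call per subagent goal.
async def decompose_and_fan_out(goals, next_fn):
    for goal in goals:
        # In the real pattern each call would carry a full subagent request.
        await next_fn({"question": goal})

next_fn = AsyncMock()
asyncio.run(decompose_and_fan_out(["Goal A", "Goal B", "Goal C"], next_fn))
assert next_fn.await_count == 3
```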


@ -0,0 +1,295 @@
"""
Structural test for the graph-rag provenance chain.
Verifies that a complete graph-rag query produces the expected
provenance chain:
question → grounding → exploration → focus → synthesis
Each step must:
- Have the correct rdf:type
- Link to its predecessor via prov:wasDerivedFrom
- Carry expected domain-specific data
"""
import pytest
from trustgraph.provenance.triples import (
question_triples,
grounding_triples,
exploration_triples,
focus_triples,
synthesis_triples,
)
from trustgraph.provenance.uris import (
question_uri,
grounding_uri,
exploration_uri,
focus_uri,
synthesis_uri,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_GRAPH_RAG_QUESTION, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT,
PROV_STARTED_AT_TIME,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
SESSION_ID = "test-session-1234"
def find_triple(triples, predicate, subject=None):
"""Find first triple matching predicate (and optionally subject)."""
for t in triples:
if t.p.iri == predicate:
if subject is None or t.s.iri == subject:
return t
return None
def find_triples(triples, predicate, subject=None):
"""Find all triples matching predicate (and optionally subject)."""
return [
t for t in triples
if t.p.iri == predicate
and (subject is None or t.s.iri == subject)
]
def has_type(triples, subject, rdf_type):
"""Check if subject has the given rdf:type."""
return any(
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
for t in triples
)
def derived_from(triples, subject):
"""Get the wasDerivedFrom target URI for a subject."""
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
return t.o.iri if t else None
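A minimal sketch of what `derived_from` computes, using `SimpleNamespace` stand-ins for the real `Term`/`Triple` schema objects (the URIs here are illustrative):

```python
from types import SimpleNamespace

DERIVED = "http://www.w3.org/ns/prov#wasDerivedFrom"

def triple(s, p, o):
    # Mimic the .s.iri / .p.iri / .o.iri shape the helpers above expect.
    return SimpleNamespace(
        s=SimpleNamespace(iri=s),
        p=SimpleNamespace(iri=p),
        o=SimpleNamespace(iri=o),
    )

def derived_from_sketch(triples, subject):
    # Return the wasDerivedFrom target for subject, or None at the root.
    for t in triples:
        if t.s.iri == subject and t.p.iri == DERIVED:
            return t.o.iri
    return None

chain = [
    triple("urn:grounding:1", DERIVED, "urn:question:1"),
    triple("urn:exploration:1", DERIVED, "urn:grounding:1"),
]
assert derived_from_sketch(chain, "urn:exploration:1") == "urn:grounding:1"
assert derived_from_sketch(chain, "urn:question:1") is None
```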
# ---------------------------------------------------------------------------
# Build the full chain
# ---------------------------------------------------------------------------
@pytest.fixture
def chain():
"""Build all provenance triples for a complete graph-rag query."""
q_uri = question_uri(SESSION_ID)
gnd_uri = grounding_uri(SESSION_ID)
exp_uri = exploration_uri(SESSION_ID)
foc_uri = focus_uri(SESSION_ID)
syn_uri = synthesis_uri(SESSION_ID)
q = question_triples(q_uri, "What is quantum computing?", "2026-01-01T00:00:00Z")
gnd = grounding_triples(gnd_uri, q_uri, ["quantum", "computing"])
exp = exploration_triples(
exp_uri, gnd_uri, edge_count=42,
entities=["urn:entity:1", "urn:entity:2"],
)
foc = focus_triples(
foc_uri, exp_uri,
selected_edges_with_reasoning=[
{
"edge": (
"http://example.com/QuantumComputing",
"http://schema.org/relatedTo",
"http://example.com/Physics",
),
"reasoning": "Directly relevant to the query",
},
{
"edge": (
"http://example.com/QuantumComputing",
"http://schema.org/name",
"Quantum Computing",
),
"reasoning": "Provides the entity label",
},
],
session_id=SESSION_ID,
)
syn = synthesis_triples(syn_uri, foc_uri, document_id="urn:doc:answer-1")
return {
"uris": {
"question": q_uri,
"grounding": gnd_uri,
"exploration": exp_uri,
"focus": foc_uri,
"synthesis": syn_uri,
},
"triples": {
"question": q,
"grounding": gnd,
"exploration": exp,
"focus": foc,
"synthesis": syn,
},
"all": q + gnd + exp + foc + syn,
}
# ---------------------------------------------------------------------------
# Chain structure tests
# ---------------------------------------------------------------------------
class TestGraphRagProvenanceChain:
"""Verify the full question → grounding → exploration → focus → synthesis chain."""
def test_chain_has_five_stages(self, chain):
"""Each stage should produce at least some triples."""
for stage in ["question", "grounding", "exploration", "focus", "synthesis"]:
assert len(chain["triples"][stage]) > 0, f"{stage} produced no triples"
def test_derivation_chain(self, chain):
"""
The wasDerivedFrom links must form:
grounding → question, exploration → grounding,
focus → exploration, synthesis → focus.
"""
uris = chain["uris"]
all_triples = chain["all"]
assert derived_from(all_triples, uris["grounding"]) == uris["question"]
assert derived_from(all_triples, uris["exploration"]) == uris["grounding"]
assert derived_from(all_triples, uris["focus"]) == uris["exploration"]
assert derived_from(all_triples, uris["synthesis"]) == uris["focus"]
def test_question_has_no_parent(self, chain):
"""The root question should not derive from anything (no parent_uri)."""
uris = chain["uris"]
all_triples = chain["all"]
assert derived_from(all_triples, uris["question"]) is None
def test_question_with_parent(self):
"""When a parent_uri is given, question should derive from it."""
q_uri = question_uri("child-session")
parent = "urn:trustgraph:agent:iteration:parent"
q = question_triples(q_uri, "sub-query", "2026-01-01T00:00:00Z",
parent_uri=parent)
assert derived_from(q, q_uri) == parent
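Taken together, the derivation tests show that the `wasDerivedFrom` links form a linked list that can be walked from any stage back to the root question. A toy illustration over a plain dict (not the triple store):

```python
def walk_to_root(links, start):
    """links maps child URI -> parent URI; returns the path to the root."""
    path = [start]
    while path[-1] in links:
        path.append(links[path[-1]])
    return path

links = {
    "synthesis": "focus",
    "focus": "exploration",
    "exploration": "grounding",
    "grounding": "question",
}
assert walk_to_root(links, "synthesis") == [
    "synthesis", "focus", "exploration", "grounding", "question",
]
```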
# ---------------------------------------------------------------------------
# Type annotation tests
# ---------------------------------------------------------------------------
class TestGraphRagProvenanceTypes:
"""Each stage must have the correct rdf:type annotations."""
def test_question_types(self, chain):
uris = chain["uris"]
triples = chain["triples"]["question"]
assert has_type(triples, uris["question"], PROV_ENTITY)
assert has_type(triples, uris["question"], TG_GRAPH_RAG_QUESTION)
def test_grounding_types(self, chain):
uris = chain["uris"]
triples = chain["triples"]["grounding"]
assert has_type(triples, uris["grounding"], PROV_ENTITY)
assert has_type(triples, uris["grounding"], TG_GROUNDING)
def test_exploration_types(self, chain):
uris = chain["uris"]
triples = chain["triples"]["exploration"]
assert has_type(triples, uris["exploration"], PROV_ENTITY)
assert has_type(triples, uris["exploration"], TG_EXPLORATION)
def test_focus_types(self, chain):
uris = chain["uris"]
triples = chain["triples"]["focus"]
assert has_type(triples, uris["focus"], PROV_ENTITY)
assert has_type(triples, uris["focus"], TG_FOCUS)
def test_synthesis_types(self, chain):
uris = chain["uris"]
triples = chain["triples"]["synthesis"]
assert has_type(triples, uris["synthesis"], PROV_ENTITY)
assert has_type(triples, uris["synthesis"], TG_SYNTHESIS)
assert has_type(triples, uris["synthesis"], TG_ANSWER_TYPE)
# ---------------------------------------------------------------------------
# Domain-specific content tests
# ---------------------------------------------------------------------------
class TestGraphRagProvenanceContent:
"""Each stage should carry the expected domain data."""
def test_question_has_query_text(self, chain):
uris = chain["uris"]
t = find_triple(chain["triples"]["question"], TG_QUERY, uris["question"])
assert t is not None
assert t.o.value == "What is quantum computing?"
def test_question_has_timestamp(self, chain):
uris = chain["uris"]
t = find_triple(chain["triples"]["question"], PROV_STARTED_AT_TIME, uris["question"])
assert t is not None
assert t.o.value == "2026-01-01T00:00:00Z"
def test_grounding_has_concepts(self, chain):
uris = chain["uris"]
concepts = find_triples(chain["triples"]["grounding"], TG_CONCEPT, uris["grounding"])
concept_values = {t.o.value for t in concepts}
assert concept_values == {"quantum", "computing"}
def test_exploration_has_edge_count(self, chain):
uris = chain["uris"]
t = find_triple(chain["triples"]["exploration"], TG_EDGE_COUNT, uris["exploration"])
assert t is not None
assert t.o.value == "42"
def test_exploration_has_entities(self, chain):
uris = chain["uris"]
entities = find_triples(chain["triples"]["exploration"], TG_ENTITY, uris["exploration"])
entity_iris = {t.o.iri for t in entities}
assert entity_iris == {"urn:entity:1", "urn:entity:2"}
def test_focus_has_selected_edges(self, chain):
uris = chain["uris"]
edges = find_triples(chain["triples"]["focus"], TG_SELECTED_EDGE, uris["focus"])
assert len(edges) == 2
def test_focus_edges_have_quoted_triples(self, chain):
"""Each edge selection entity should have a tg:edge with a quoted triple."""
focus = chain["triples"]["focus"]
edge_triples = find_triples(focus, TG_EDGE)
assert len(edge_triples) == 2
# Each should have a quoted triple as the object
for t in edge_triples:
assert t.o.triple is not None, "tg:edge object should be a quoted triple"
def test_focus_edges_have_reasoning(self, chain):
"""Each edge selection entity should have tg:reasoning."""
focus = chain["triples"]["focus"]
reasoning = find_triples(focus, TG_REASONING)
assert len(reasoning) == 2
reasoning_texts = {t.o.value for t in reasoning}
assert "Directly relevant to the query" in reasoning_texts
assert "Provides the entity label" in reasoning_texts
def test_synthesis_has_document_ref(self, chain):
uris = chain["uris"]
t = find_triple(chain["triples"]["synthesis"], TG_DOCUMENT, uris["synthesis"])
assert t is not None
assert t.o.iri == "urn:doc:answer-1"
def test_synthesis_has_labels(self, chain):
uris = chain["uris"]
t = find_triple(chain["triples"]["synthesis"], RDFS_LABEL, uris["synthesis"])
assert t is not None
assert t.o.value == "Synthesis"


@ -0,0 +1,380 @@
"""
Integration test: run a full DocumentRag.query() with mocked subsidiary
clients and verify the explain_callback receives the complete provenance
chain in the correct order with correct structure.
Document-RAG provenance chain (4 stages):
question → grounding → exploration → synthesis
"""
import pytest
from unittest.mock import AsyncMock
from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_SYNTHESIS, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT,
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
for t in triples:
if t.p.iri == predicate:
if subject is None or t.s.iri == subject:
return t
return None
def find_triples(triples, predicate, subject=None):
return [
t for t in triples
if t.p.iri == predicate
and (subject is None or t.s.iri == subject)
]
def has_type(triples, subject, rdf_type):
return any(
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
for t in triples
)
def derived_from(triples, subject):
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
return t.o.iri if t else None
@dataclass
class ChunkMatch:
"""Mimics the result from doc_embeddings_client.query()."""
chunk_id: str
# ---------------------------------------------------------------------------
# Mock setup
# ---------------------------------------------------------------------------
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
CHUNK_B_CONTENT = "Refunds are processed to the original payment method."
def build_mock_clients():
"""
Build mock clients for a document-rag query.
Client call sequence during query():
1. prompt_client.prompt("extract-concepts", ...) -> concepts
2. embeddings_client.embed(concepts) -> vectors
3. doc_embeddings_client.query(vector, ...) -> chunk matches
4. fetch_chunk(chunk_id, user) -> chunk content
5. prompt_client.document_prompt(query, documents) -> answer
"""
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
doc_embeddings_client = AsyncMock()
fetch_chunk = AsyncMock()
# 1. Concept extraction
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return "return policy\nrefund"
return ""
prompt_client.prompt.side_effect = mock_prompt
# 2. Embedding vectors
embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
# 3. Chunk matching
doc_embeddings_client.query.return_value = [
ChunkMatch(chunk_id=CHUNK_A),
ChunkMatch(chunk_id=CHUNK_B),
]
# 4. Chunk content
async def mock_fetch(chunk_id, user):
return {
CHUNK_A: CHUNK_A_CONTENT,
CHUNK_B: CHUNK_B_CONTENT,
}[chunk_id]
fetch_chunk.side_effect = mock_fetch
# 5. Synthesis
prompt_client.document_prompt.return_value = (
"Items can be returned within 30 days for a full refund."
)
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
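The five-step call sequence in the docstring above can be exercised end-to-end with a toy pipeline wired to `AsyncMock` clients; the `toy_query` function and its argument names are illustrative, not the real `DocumentRag` implementation:

```python
import asyncio
from unittest.mock import AsyncMock

async def toy_query(prompt, embed, search, fetch, question):
    # 1. Extract concepts, 2. embed them, 3. match chunks,
    # 4. fetch chunk content, 5. synthesise an answer.
    concepts = (await prompt("extract-concepts", question)).split("\n")
    vectors = await embed(concepts)
    chunk_ids = await search(vectors[0])
    documents = [await fetch(cid) for cid in chunk_ids]
    return await prompt("document-synthesis", documents)

prompt = AsyncMock(side_effect=["policy\nrefund", "The answer."])
embed = AsyncMock(return_value=[[0.1], [0.2]])
search = AsyncMock(return_value=["chunk-0", "chunk-1"])
fetch = AsyncMock(side_effect=["text A", "text B"])

answer = asyncio.run(toy_query(prompt, embed, search, fetch, "Returns?"))
assert answer == "The answer."
assert fetch.await_count == 2
```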
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestDocumentRagQueryProvenance:
"""
Run a real DocumentRag.query() and verify the provenance chain emitted
via explain_callback.
"""
@pytest.mark.asyncio
async def test_explain_callback_receives_four_events(self):
"""query() should emit exactly 4 explain events."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
assert len(events) == 4, (
f"Expected 4 explain events (question, grounding, exploration, "
f"synthesis), got {len(events)}"
)
@pytest.mark.asyncio
async def test_events_have_correct_types_in_order(self):
"""
Events should arrive as:
question, grounding, exploration, synthesis.
"""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
expected_types = [
TG_DOC_RAG_QUESTION,
TG_GROUNDING,
TG_EXPLORATION,
TG_SYNTHESIS,
]
for i, expected_type in enumerate(expected_types):
uri = events[i]["explain_id"]
triples = events[i]["triples"]
assert has_type(triples, uri, expected_type), (
f"Event {i} (uri={uri}) should have type {expected_type}"
)
@pytest.mark.asyncio
async def test_derivation_chain_links_correctly(self):
"""
Each event's URI should link to the previous via wasDerivedFrom:
question → (none)
grounding → question
exploration → grounding
synthesis → exploration
"""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
uris = [e["explain_id"] for e in events]
all_triples = []
for e in events:
all_triples.extend(e["triples"])
# question has no parent
assert derived_from(all_triples, uris[0]) is None
# grounding → question
assert derived_from(all_triples, uris[1]) == uris[0]
# exploration → grounding
assert derived_from(all_triples, uris[2]) == uris[1]
# synthesis → exploration
assert derived_from(all_triples, uris[3]) == uris[2]
@pytest.mark.asyncio
async def test_question_carries_query_text(self):
"""The question event should contain the original query string."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
q_uri = events[0]["explain_id"]
q_triples = events[0]["triples"]
t = find_triple(q_triples, TG_QUERY, q_uri)
assert t is not None
assert t.o.value == "What is the return policy?"
@pytest.mark.asyncio
async def test_grounding_carries_concepts(self):
"""The grounding event should list extracted concepts."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
gnd_uri = events[1]["explain_id"]
gnd_triples = events[1]["triples"]
concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
concept_values = {t.o.value for t in concepts}
assert "return policy" in concept_values
assert "refund" in concept_values
@pytest.mark.asyncio
async def test_exploration_has_chunk_count(self):
"""The exploration event should report the number of chunks retrieved."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
exp_uri = events[2]["explain_id"]
exp_triples = events[2]["triples"]
t = find_triple(exp_triples, TG_CHUNK_COUNT, exp_uri)
assert t is not None
assert int(t.o.value) == 2
@pytest.mark.asyncio
async def test_exploration_has_selected_chunks(self):
"""The exploration event should list the chunk IDs that were fetched."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
exp_uri = events[2]["explain_id"]
exp_triples = events[2]["triples"]
chunks = find_triples(exp_triples, TG_SELECTED_CHUNK, exp_uri)
chunk_iris = {t.o.iri for t in chunks}
assert CHUNK_A in chunk_iris
assert CHUNK_B in chunk_iris
@pytest.mark.asyncio
async def test_synthesis_is_answer_type(self):
"""The synthesis event should have tg:Synthesis and tg:Answer types."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
syn_uri = events[3]["explain_id"]
syn_triples = events[3]["triples"]
assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)
@pytest.mark.asyncio
async def test_query_returns_answer_text(self):
"""query() should return the synthesised answer."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(
query="What is the return policy?",
explain_callback=AsyncMock(),
)
assert result == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_no_explain_callback_still_works(self):
"""query() without explain_callback should return answer normally."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(query="What is the return policy?")
assert result == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):
"""All emitted triples should be in the urn:graph:retrieval graph."""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
for event in events:
for t in event["triples"]:
assert t.g == "urn:graph:retrieval", (
f"Triple {t.s.iri} {t.p.iri} should be in "
f"urn:graph:retrieval, got {t.g}"
)


@ -465,12 +465,15 @@ class TestQuery:
     return_value=(["entity1", "entity2"], ["concept1"])
 )
-query.follow_edges_batch = AsyncMock(return_value={
-    ("entity1", "predicate1", "object1"),
-    ("entity2", "predicate2", "object2")
-})
+query.follow_edges_batch = AsyncMock(return_value=(
+    {
+        ("entity1", "predicate1", "object1"),
+        ("entity2", "predicate2", "object2")
+    },
+    {}
+))
-subgraph, entities, concepts = await query.get_subgraph("test query")
+subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
 query.get_entities.assert_called_once_with("test query")
 query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
@ -503,7 +506,7 @@ class TestQuery:
 test_entities = ["entity1", "entity3"]
 test_concepts = ["concept1"]
 query.get_subgraph = AsyncMock(
-    return_value=(test_subgraph, test_entities, test_concepts)
+    return_value=(test_subgraph, {}, test_entities, test_concepts)
 )
 async def mock_maybe_label(entity):


@ -0,0 +1,358 @@
"""
Tests that explain_triples are forwarded correctly through the graph-rag
service and client layers.
Covers:
- Service: explain messages include triples from the provenance callback
- Client: explain_callback receives explain_triples from the response
- End-to-end: triples survive the full service → client → callback chain
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.schema import (
GraphRagQuery, GraphRagResponse,
Triple, Term, IRI, LITERAL,
)
from trustgraph.base.graph_rag_client import GraphRagClient
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_triple(s_iri, p_iri, o_value, o_type=IRI):
"""Create a Triple with IRI subject/predicate and typed object."""
o = (
Term(type=IRI, iri=o_value) if o_type == IRI
else Term(type=LITERAL, value=o_value)
)
return Triple(
s=Term(type=IRI, iri=s_iri),
p=Term(type=IRI, iri=p_iri),
o=o,
)
def sample_focus_triples():
"""Focus-style triples with a quoted triple (edge selection)."""
return [
make_triple(
"urn:trustgraph:focus:abc",
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
"https://trustgraph.ai/ns/Focus",
),
make_triple(
"urn:trustgraph:focus:abc",
"http://www.w3.org/ns/prov#wasDerivedFrom",
"urn:trustgraph:exploration:abc",
),
make_triple(
"urn:trustgraph:focus:abc",
"https://trustgraph.ai/ns/selectedEdge",
"urn:trustgraph:edge-sel:abc:0",
),
]
def sample_question_triples():
"""Question-style triples."""
return [
make_triple(
"urn:trustgraph:question:abc",
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
"https://trustgraph.ai/ns/GraphRagQuestion",
),
make_triple(
"urn:trustgraph:question:abc",
"https://trustgraph.ai/ns/query",
"What is quantum computing?",
o_type=LITERAL,
),
]
# ---------------------------------------------------------------------------
# Service-level: explain messages carry triples
# ---------------------------------------------------------------------------
class TestGraphRagServiceExplainTriples:
"""Test that the graph-rag service includes explain_triples in messages."""
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
@pytest.mark.asyncio
async def test_explain_messages_include_triples(self, mock_graph_rag_class):
"""
When the provenance callback is invoked with triples, the service
should include them in the explain response message.
"""
from trustgraph.retrieval.graph_rag.rag import Processor
processor = Processor(
taskgroup=MagicMock(),
id="test-processor",
entity_limit=50,
triple_limit=30,
max_subgraph_size=150,
max_path_length=2,
)
mock_rag_instance = AsyncMock()
mock_graph_rag_class.return_value = mock_rag_instance
question_triples = sample_question_triples()
focus_triples = sample_focus_triples()
async def mock_query(**kwargs):
explain_callback = kwargs.get('explain_callback')
if explain_callback:
await explain_callback(
question_triples, "urn:trustgraph:question:abc"
)
await explain_callback(
focus_triples, "urn:trustgraph:focus:abc"
)
return "The answer."
mock_rag_instance.query.side_effect = mock_query
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is quantum computing?",
user="trustgraph",
collection="default",
streaming=False,
)
msg.properties.return_value = {"id": "test-id"}
consumer = MagicMock()
flow = MagicMock()
mock_response = AsyncMock()
mock_provenance = AsyncMock()
def flow_router(name):
if name == "response":
return mock_response
if name == "explainability":
return mock_provenance
return AsyncMock()
flow.side_effect = flow_router
await processor.on_request(msg, consumer, flow)
# Find the explain messages
explain_msgs = [
call[0][0]
for call in mock_response.send.call_args_list
if call[0][0].message_type == "explain"
]
assert len(explain_msgs) == 2
# First explain message should carry question triples
assert explain_msgs[0].explain_id == "urn:trustgraph:question:abc"
assert explain_msgs[0].explain_triples == question_triples
# Second explain message should carry focus triples
assert explain_msgs[1].explain_id == "urn:trustgraph:focus:abc"
assert explain_msgs[1].explain_triples == focus_triples
# ---------------------------------------------------------------------------
# Client-level: explain_callback receives triples
# ---------------------------------------------------------------------------

class TestGraphRagClientExplainForwarding:
    """Test that GraphRagClient.rag() forwards explain_triples to callback."""

    @pytest.mark.asyncio
    async def test_explain_callback_receives_triples(self):
        """
        The explain_callback should receive (explain_id, explain_graph,
        explain_triples), not just (explain_id, explain_graph).
        """
        focus_triples = sample_focus_triples()

        # Simulate the response sequence the client would receive
        responses = [
            GraphRagResponse(
                message_type="explain",
                explain_id="urn:trustgraph:focus:abc",
                explain_graph="urn:graph:retrieval",
                explain_triples=focus_triples,
            ),
            GraphRagResponse(
                message_type="chunk",
                response="The answer.",
                end_of_stream=True,
            ),
            GraphRagResponse(
                message_type="chunk",
                response="",
                end_of_session=True,
            ),
        ]

        # Capture what the explain_callback receives
        received_calls = []

        async def explain_callback(explain_id, explain_graph, explain_triples):
            received_calls.append({
                "explain_id": explain_id,
                "explain_graph": explain_graph,
                "explain_triples": explain_triples,
            })

        # Patch self.request to feed responses to the recipient
        client = GraphRagClient.__new__(GraphRagClient)

        async def mock_request(req, timeout=600, recipient=None):
            for resp in responses:
                done = await recipient(resp)
                if done:
                    return resp

        client.request = mock_request

        result = await client.rag(
            query="test",
            explain_callback=explain_callback,
        )

        assert result == "The answer."
        assert len(received_calls) == 1
        assert received_calls[0]["explain_id"] == "urn:trustgraph:focus:abc"
        assert received_calls[0]["explain_graph"] == "urn:graph:retrieval"
        assert received_calls[0]["explain_triples"] == focus_triples

    @pytest.mark.asyncio
    async def test_explain_callback_receives_empty_triples(self):
        """
        When an explain event has no triples, the callback should still
        receive an empty list (not None or missing).
        """
        responses = [
            GraphRagResponse(
                message_type="explain",
                explain_id="urn:trustgraph:question:abc",
                explain_graph="urn:graph:retrieval",
                explain_triples=[],
            ),
            GraphRagResponse(
                message_type="chunk",
                response="Answer.",
                end_of_stream=True,
                end_of_session=True,
            ),
        ]

        received_calls = []

        async def explain_callback(explain_id, explain_graph, explain_triples):
            received_calls.append(explain_triples)

        client = GraphRagClient.__new__(GraphRagClient)

        async def mock_request(req, timeout=600, recipient=None):
            for resp in responses:
                done = await recipient(resp)
                if done:
                    return resp

        client.request = mock_request

        await client.rag(query="test", explain_callback=explain_callback)

        assert len(received_calls) == 1
        assert received_calls[0] == []

    @pytest.mark.asyncio
    async def test_multiple_explain_events_all_forward_triples(self):
        """
        Each explain event in a session should forward its own triples.
        """
        q_triples = sample_question_triples()
        f_triples = sample_focus_triples()
        responses = [
            GraphRagResponse(
                message_type="explain",
                explain_id="urn:trustgraph:question:abc",
                explain_graph="urn:graph:retrieval",
                explain_triples=q_triples,
            ),
            GraphRagResponse(
                message_type="explain",
                explain_id="urn:trustgraph:focus:abc",
                explain_graph="urn:graph:retrieval",
                explain_triples=f_triples,
            ),
            GraphRagResponse(
                message_type="chunk",
                response="Answer.",
                end_of_stream=True,
                end_of_session=True,
            ),
        ]

        received_calls = []

        async def explain_callback(explain_id, explain_graph, explain_triples):
            received_calls.append({
                "explain_id": explain_id,
                "explain_triples": explain_triples,
            })

        client = GraphRagClient.__new__(GraphRagClient)

        async def mock_request(req, timeout=600, recipient=None):
            for resp in responses:
                done = await recipient(resp)
                if done:
                    return resp

        client.request = mock_request

        await client.rag(query="test", explain_callback=explain_callback)

        assert len(received_calls) == 2
        assert received_calls[0]["explain_id"] == "urn:trustgraph:question:abc"
        assert received_calls[0]["explain_triples"] == q_triples
        assert received_calls[1]["explain_id"] == "urn:trustgraph:focus:abc"
        assert received_calls[1]["explain_triples"] == f_triples

    @pytest.mark.asyncio
    async def test_no_explain_callback_does_not_error(self):
        """
        When no explain_callback is provided, explain events should be
        silently skipped without errors.
        """
        responses = [
            GraphRagResponse(
                message_type="explain",
                explain_id="urn:trustgraph:question:abc",
                explain_graph="urn:graph:retrieval",
                explain_triples=sample_question_triples(),
            ),
            GraphRagResponse(
                message_type="chunk",
                response="Answer.",
                end_of_stream=True,
                end_of_session=True,
            ),
        ]

        client = GraphRagClient.__new__(GraphRagClient)

        async def mock_request(req, timeout=600, recipient=None):
            for resp in responses:
                done = await recipient(resp)
                if done:
                    return resp

        client.request = mock_request

        result = await client.rag(query="test")
        assert result == "Answer."


@ -0,0 +1,482 @@
"""
Integration test: run a full GraphRag.query() with mocked subsidiary clients
and verify the explain_callback receives the complete provenance chain
in the correct order with correct structure.
This tests the real query() method end-to-end, not just the triple builders.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock
from dataclasses import dataclass
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
from trustgraph.provenance.namespaces import (
    RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
    TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
    TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
    TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
    TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def find_triple(triples, predicate, subject=None):
    for t in triples:
        if t.p.iri == predicate:
            if subject is None or t.s.iri == subject:
                return t
    return None


def find_triples(triples, predicate, subject=None):
    return [
        t for t in triples
        if t.p.iri == predicate
        and (subject is None or t.s.iri == subject)
    ]


def has_type(triples, subject, rdf_type):
    return any(
        t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
        for t in triples
    )


def derived_from(triples, subject):
    t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    return t.o.iri if t else None


@dataclass
class EmbeddingMatch:
    """Mimics the result from graph_embeddings_client.query()."""
    entity: Term
# ---------------------------------------------------------------------------
# Mock setup
# ---------------------------------------------------------------------------
# A tiny knowledge graph: 2 entities, 3 edges
ENTITY_A = "http://example.com/QuantumComputing"
ENTITY_B = "http://example.com/Physics"
EDGE_1 = (ENTITY_A, "http://schema.org/relatedTo", ENTITY_B)
EDGE_2 = (ENTITY_A, "http://schema.org/name", "Quantum Computing")
EDGE_3 = (ENTITY_B, "http://schema.org/name", "Physics")
def make_schema_triple(s, p, o):
    """Create a SchemaTriple from string values."""
    return SchemaTriple(
        s=Term(type=IRI, iri=s),
        p=Term(type=IRI, iri=p),
        o=Term(type=IRI, iri=o) if o.startswith("http") else Term(type=LITERAL, value=o),
    )
def build_mock_clients():
    """
    Build mock clients that simulate a small knowledge graph query.

    Client call sequence during query():
      1. prompt_client.prompt("extract-concepts", ...) -> concepts
      2. embeddings_client.embed(concepts) -> vectors
      3. graph_embeddings_client.query(vector, ...) -> entity matches
      4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
      5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
      6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
      7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
      8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
      9. prompt_client.prompt("kg-synthesis", ...) -> final answer
    """
    prompt_client = AsyncMock()
    embeddings_client = AsyncMock()
    graph_embeddings_client = AsyncMock()
    triples_client = AsyncMock()

    # 1. Concept extraction
    prompt_responses = {}
    prompt_responses["extract-concepts"] = "quantum computing\nphysics"

    # 2. Embedding vectors (simple fake vectors)
    embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]

    # 3. Entity lookup - return our two entities
    graph_embeddings_client.query.return_value = [
        EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_A)),
        EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
    ]

    # 4. Triple queries (follow_edges_batch) - return our edges
    kg_triples = [
        make_schema_triple(*EDGE_1),
        make_schema_triple(*EDGE_2),
        make_schema_triple(*EDGE_3),
    ]
    triples_client.query_stream.return_value = kg_triples

    # 5. Label resolution - return entity as its own label (simplify)
    async def mock_label_query(s=None, p=None, o=None, limit=1,
                               user=None, collection=None, g=None):
        return []  # No labels found, will fall back to URI

    triples_client.query.side_effect = mock_label_query

    # 6+7. Edge scoring and reasoning: dynamically score/reason about
    # whatever edges the query method sends us, since edge IDs are computed
    # from str(Term) representations which include the full dataclass repr.
    synthesis_answer = "Quantum computing applies physics principles to computation."

    async def mock_prompt(template_id, variables=None, **kwargs):
        if template_id == "extract-concepts":
            return prompt_responses["extract-concepts"]
        elif template_id == "kg-edge-scoring":
            # Score all edges highly, using the IDs that GraphRag computed
            edges = variables.get("knowledge", [])
            return [
                {"id": e["id"], "score": 10 - i}
                for i, e in enumerate(edges)
            ]
        elif template_id == "kg-edge-reasoning":
            # Provide reasoning for each edge
            edges = variables.get("knowledge", [])
            return [
                {"id": e["id"], "reasoning": f"Relevant edge {i}"}
                for i, e in enumerate(edges)
            ]
        elif template_id == "kg-synthesis":
            return synthesis_answer
        return ""

    prompt_client.prompt.side_effect = mock_prompt

    return prompt_client, embeddings_client, graph_embeddings_client, triples_client
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestGraphRagQueryProvenance:
    """
    Run a real GraphRag.query() and verify the provenance chain emitted
    via explain_callback.
    """

    @pytest.mark.asyncio
    async def test_explain_callback_receives_five_events(self):
        """query() should emit exactly 5 explain events."""
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,  # skip semantic pre-filter for simplicity
        )

        assert len(events) == 5, (
            f"Expected 5 explain events (question, grounding, exploration, "
            f"focus, synthesis), got {len(events)}"
        )

    @pytest.mark.asyncio
    async def test_events_have_correct_types_in_order(self):
        """
        Events should arrive as:
        question, grounding, exploration, focus, synthesis.
        """
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,
        )

        expected_types = [
            TG_GRAPH_RAG_QUESTION,
            TG_GROUNDING,
            TG_EXPLORATION,
            TG_FOCUS,
            TG_SYNTHESIS,
        ]
        for i, expected_type in enumerate(expected_types):
            uri = events[i]["explain_id"]
            triples = events[i]["triples"]
            assert has_type(triples, uri, expected_type), (
                f"Event {i} (uri={uri}) should have type {expected_type}"
            )

    @pytest.mark.asyncio
    async def test_derivation_chain_links_correctly(self):
        """
        Each event's URI should link to the previous via wasDerivedFrom:

            question    <- (none)
            grounding   <- question
            exploration <- grounding
            focus       <- exploration
            synthesis   <- focus
        """
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,
        )

        uris = [e["explain_id"] for e in events]
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])

        # question has no parent
        assert derived_from(all_triples, uris[0]) is None
        # grounding → question
        assert derived_from(all_triples, uris[1]) == uris[0]
        # exploration → grounding
        assert derived_from(all_triples, uris[2]) == uris[1]
        # focus → exploration
        assert derived_from(all_triples, uris[3]) == uris[2]
        # synthesis → focus
        assert derived_from(all_triples, uris[4]) == uris[3]

    @pytest.mark.asyncio
    async def test_question_event_carries_query_text(self):
        """The question event should contain the original query string."""
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,
        )

        q_uri = events[0]["explain_id"]
        q_triples = events[0]["triples"]
        t = find_triple(q_triples, TG_QUERY, q_uri)
        assert t is not None
        assert t.o.value == "What is quantum computing?"

    @pytest.mark.asyncio
    async def test_grounding_carries_concepts(self):
        """The grounding event should list extracted concepts."""
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,
        )

        gnd_uri = events[1]["explain_id"]
        gnd_triples = events[1]["triples"]
        concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
        concept_values = {t.o.value for t in concepts}
        assert "quantum computing" in concept_values
        assert "physics" in concept_values

    @pytest.mark.asyncio
    async def test_exploration_has_edge_count(self):
        """The exploration event should report how many edges were found."""
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,
        )

        exp_uri = events[2]["explain_id"]
        exp_triples = events[2]["triples"]
        t = find_triple(exp_triples, TG_EDGE_COUNT, exp_uri)
        assert t is not None
        # Should be non-zero (we provided 3 edges, label edges filtered)
        assert int(t.o.value) > 0

    @pytest.mark.asyncio
    async def test_focus_has_selected_edges_with_reasoning(self):
        """
        The focus event should carry selected edges as quoted triples
        with reasoning text.
        """
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,
        )

        foc_uri = events[3]["explain_id"]
        foc_triples = events[3]["triples"]

        # Should have selected edges
        selected = find_triples(foc_triples, TG_SELECTED_EDGE, foc_uri)
        assert len(selected) > 0, "Focus should have at least one selected edge"

        # Each edge selection should have a quoted triple
        edge_t = find_triples(foc_triples, TG_EDGE)
        assert len(edge_t) > 0, "Focus should have tg:edge with quoted triples"
        for t in edge_t:
            assert t.o.triple is not None, "tg:edge object must be a quoted triple"

        # Should have reasoning
        reasoning = find_triples(foc_triples, TG_REASONING)
        assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
        reasoning_texts = {t.o.value for t in reasoning}
        assert any(r for r in reasoning_texts), "Reasoning should not be empty"

    @pytest.mark.asyncio
    async def test_synthesis_is_answer_type(self):
        """The synthesis event should have tg:Answer type."""
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,
        )

        syn_uri = events[4]["explain_id"]
        syn_triples = events[4]["triples"]
        assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
        assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)

    @pytest.mark.asyncio
    async def test_query_returns_answer_text(self):
        """query() should still return the synthesised answer."""
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        result = await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,
        )

        assert result == "Quantum computing applies physics principles to computation."

    @pytest.mark.asyncio
    async def test_parent_uri_links_question_to_parent(self):
        """When parent_uri is provided, question should derive from it."""
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        parent = "urn:trustgraph:agent:iteration:xyz"
        await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,
            parent_uri=parent,
        )

        q_uri = events[0]["explain_id"]
        q_triples = events[0]["triples"]
        assert derived_from(q_triples, q_uri) == parent

    @pytest.mark.asyncio
    async def test_no_explain_callback_still_works(self):
        """query() without explain_callback should return answer normally."""
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        result = await rag.query(
            query="What is quantum computing?",
            edge_score_limit=0,
        )

        assert result == "Quantum computing applies physics principles to computation."

    @pytest.mark.asyncio
    async def test_all_triples_in_retrieval_graph(self):
        """All emitted triples should be in the urn:graph:retrieval graph."""
        clients = build_mock_clients()
        rag = GraphRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is quantum computing?",
            explain_callback=explain_callback,
            edge_score_limit=0,
        )

        for event in events:
            for t in event["triples"]:
                assert t.g == "urn:graph:retrieval", (
                    f"Triple {t.s.iri} {t.p.iri} should be in "
                    f"urn:graph:retrieval, got {t.g}"
                )


@ -15,7 +15,7 @@ class GraphRagClient(RequestResponse):
        user: User identifier
        collection: Collection identifier
        chunk_callback: Optional async callback(text, end_of_stream) for text chunks
        explain_callback: Optional async callback(explain_id, explain_graph, explain_triples) for explain notifications
        timeout: Request timeout in seconds

    Returns:
@ -30,7 +30,7 @@ class GraphRagClient(RequestResponse):
            # Handle explain notifications
            if resp.message_type == 'explain':
                if explain_callback and resp.explain_id:
                    await explain_callback(resp.explain_id, resp.explain_graph, resp.explain_triples)
                return False  # Continue receiving

            # Handle text chunks


@ -43,7 +43,7 @@ class DocumentRagClient(BaseClient):
        user: User identifier
        collection: Collection identifier
        chunk_callback: Optional callback(text, end_of_stream) for text chunks
        explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications
        timeout: Request timeout in seconds

    Returns:
@ -55,7 +55,7 @@ class DocumentRagClient(BaseClient):
            # Handle explain notifications (response is None/empty, explain_id present)
            if x.explain_id and not x.response:
                if explain_callback:
                    explain_callback(x.explain_id, x.explain_graph, x.explain_triples)
                return False  # Continue receiving

            # Handle text chunks


@ -47,7 +47,7 @@ class GraphRagClient(BaseClient):
        user: User identifier
        collection: Collection identifier
        chunk_callback: Optional callback(text, end_of_stream) for text chunks
        explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications
        timeout: Request timeout in seconds

    Returns:
@ -59,7 +59,7 @@ class GraphRagClient(BaseClient):
            # Handle explain notifications
            if x.message_type == 'explain':
                if explain_callback and x.explain_id:
                    explain_callback(x.explain_id, x.explain_graph, x.explain_triples)
                return False  # Continue receiving

            # Handle text chunks


@ -465,11 +465,18 @@ def exploration_triples(
    return triples


def _quoted_triple(s, p, o) -> Term:
    """Create a quoted triple term (RDF-star).

    Accepts either Term objects (preserving original types) or plain
    strings (treated as IRIs for backward compatibility).
    """
    s_term = s if isinstance(s, Term) else _iri(s)
    p_term = p if isinstance(p, Term) else _iri(p)
    o_term = o if isinstance(o, Term) else _iri(o)
    return Term(
        type=TRIPLE,
        triple=Triple(s=s_term, p=p_term, o=o_term)
    )


@ -39,13 +39,14 @@ class KnowledgeQueryImpl:
        if respond:
            from ... schema import AgentResponse

            async def explain_callback(explain_id, explain_graph, explain_triples=None):
                self.context.last_sub_explain_uri = explain_id
                await respond(AgentResponse(
                    chunk_type="explain",
                    content="",
                    explain_id=explain_id,
                    explain_graph=explain_graph,
                    explain_triples=explain_triples or [],
                ))

        if current_uri:
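The hunk above normalises `explain_triples` to a list before wrapping it in an `AgentResponse` chunk. A minimal self-contained sketch of that callback pattern (the `AgentResponse` here is a stand-in dataclass for illustration, not trustgraph's schema class):

```python
import asyncio
from dataclasses import dataclass, field

# Stand-in for the AgentResponse schema class (assumption for illustration)
@dataclass
class AgentResponse:
    chunk_type: str
    content: str = ""
    explain_id: str = None
    explain_graph: str = None
    explain_triples: list = field(default_factory=list)

async def main():
    sent = []

    async def respond(msg):
        sent.append(msg)

    # Mirrors the fixed callback: explain_triples defaults to None and is
    # normalised to [] so downstream consumers always receive a list.
    async def explain_callback(explain_id, explain_graph, explain_triples=None):
        await respond(AgentResponse(
            chunk_type="explain",
            explain_id=explain_id,
            explain_graph=explain_graph,
            explain_triples=explain_triples or [],
        ))

    await explain_callback("urn:trustgraph:question:abc", "urn:graph:retrieval")
    await explain_callback("urn:trustgraph:focus:abc", "urn:graph:retrieval",
                           [("s", "p", "o")])
    return sent

sent = asyncio.run(main())
```

The default of `None` keeps the callback backward compatible with two-argument callers, which is why the agent tool code in the hunk declares `explain_triples=None` rather than requiring it.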


@ -10,6 +10,7 @@ from collections import OrderedDict
from datetime import datetime

from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE
from ... knowledge import Uri, Literal

# Provenance imports
from trustgraph.provenance import (
@ -46,6 +47,26 @@ def term_to_string(term):
    return term.iri or term.value or str(term)
def to_term(val):
    """Convert a Uri, Literal, or string to a schema Term.

    The triples client returns Uri/Literal (str subclasses) rather than
    Term objects. This converts them back so provenance quoted triples
    preserve the correct type.
    """
    if isinstance(val, Term):
        return val
    if isinstance(val, Uri):
        return Term(type=IRI, iri=str(val))
    if isinstance(val, Literal):
        return Term(type=LITERAL, value=str(val))
    # Fallback: treat as IRI if it looks like one, otherwise literal
    s = str(val)
    if s.startswith(("http://", "https://", "urn:")):
        return Term(type=IRI, iri=s)
    return Term(type=LITERAL, value=s)
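The point of `to_term` is that a `Literal` whose text happens to look like a URL must stay a literal; only untyped strings fall back to the scheme sniff. A runnable sketch with simplified stand-ins for `Uri`, `Literal`, and `Term` (assumptions for illustration; the real types live in `trustgraph.knowledge` and `trustgraph.schema`):

```python
from dataclasses import dataclass

# Stand-ins (assumptions): Uri and Literal are str subclasses, as in the
# commit message; Term is a tagged value with a type discriminator.
class Uri(str): pass
class Literal(str): pass

IRI, LITERAL = "iri", "literal"

@dataclass
class Term:
    type: str
    iri: str = None
    value: str = None

def to_term(val):
    # Mirror of the diff's logic: explicit type wins over the scheme sniff.
    if isinstance(val, Term):
        return val
    if isinstance(val, Uri):
        return Term(type=IRI, iri=str(val))
    if isinstance(val, Literal):
        return Term(type=LITERAL, value=str(val))
    s = str(val)
    if s.startswith(("http://", "https://", "urn:")):
        return Term(type=IRI, iri=s)
    return Term(type=LITERAL, value=s)
```

Note that `to_term(Literal("http://example.com/x"))` stays a literal even though the text looks like an IRI; that is exactly the mistyping the fix addresses for values such as `skos:definition` strings.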
def edge_id(s, p, o):
    """Generate an 8-character hash ID for an edge (s, p, o)."""
    edge_str = f"{s}|{p}|{o}"
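The hunk boundary cuts `edge_id` off after building `edge_str`. A self-contained sketch of an 8-character edge ID in the same spirit (the choice of sha256 is an assumption; the diff truncates before the actual hash call):

```python
import hashlib

def edge_id(s, p, o):
    # Deterministic 8-character ID for an edge (s, p, o). The source
    # builds "s|p|o" and hashes it; sha256 here is illustrative.
    edge_str = f"{s}|{p}|{o}"
    return hashlib.sha256(edge_str.encode("utf-8")).hexdigest()[:8]
```

The same (s, p, o) always yields the same ID, which is what lets the edge-scoring prompt refer back to edges by ID.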
@ -258,10 +279,18 @@ class Query:
        return all_triples

    async def follow_edges_batch(self, entities, max_depth):
        """Optimized iterative graph traversal with batching.

        Returns:
            tuple: (subgraph, term_map) where subgraph is a set of
            (str, str, str) tuples and term_map maps each string tuple
            to its original (Term, Term, Term) for type-preserving
            provenance.
        """
        visited = set()
        current_level = set(entities)
        subgraph = set()
        term_map = {}  # (str, str, str) -> (Term, Term, Term)

        for depth in range(max_depth):
            if not current_level or len(subgraph) >= self.max_subgraph_size:
@ -282,6 +311,7 @@ class Query:
                for triple in triples:
                    triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
                    subgraph.add(triple_tuple)
                    term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o))

                # Collect entities for next level (only from s and o positions)
                if depth < max_depth - 1:  # Don't collect for final depth
@ -293,13 +323,13 @@ class Query:

            # Stop if subgraph size limit reached
            if len(subgraph) >= self.max_subgraph_size:
                return subgraph, term_map

            # Update for next iteration
            visited.update(current_level)
            current_level = next_level

        return subgraph, term_map

    async def follow_edges(self, ent, subgraph, path_length):
        """Legacy method - replaced by follow_edges_batch"""
@ -311,7 +341,7 @@ class Query:
            return

        # For backward compatibility, convert to new approach
        batch_result, _ = await self.follow_edges_batch([ent], path_length)
        subgraph.update(batch_result)
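The traversal change can be illustrated without the async triples client. A simplified sketch of the batched breadth-first expansion plus the new `term_map`, using an in-memory dict as the graph (names mirror the diff, but the client machinery and Term types are replaced for illustration):

```python
# Toy graph: entity -> list of (s, p, o) edges (illustrative data)
GRAPH = {
    "a": [("a", "rel", "b"), ("a", "name", "Thing A")],
    "b": [("b", "name", "Thing B")],
}

def follow_edges_batch(entities, max_depth, max_size=100):
    """Iterative BFS over GRAPH, returning (subgraph, term_map) where
    subgraph holds string tuples and term_map keeps the original values,
    mirroring the type-preserving mapping the diff introduces."""
    visited, current = set(), set(entities)
    subgraph, term_map = set(), {}
    for depth in range(max_depth):
        if not current or len(subgraph) >= max_size:
            break
        next_level = set()
        for ent in current:
            for s, p, o in GRAPH.get(ent, []):
                key = (str(s), str(p), str(o))
                subgraph.add(key)
                term_map[key] = (s, p, o)  # original (typed) values
                if depth < max_depth - 1:  # don't expand past final depth
                    next_level.add(o)
        visited |= current
        current = next_level - visited
    return subgraph, term_map
```

In the real code the values stored in `term_map` are `Term` objects produced by `to_term`, which is what lets `uri_map` hand type-correct triples to `_quoted_triple` later in the chain.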
    async def get_subgraph(self, query):
@ -319,9 +349,10 @@ class Query:
        Get subgraph by extracting concepts, finding entities, and traversing.

        Returns:
            tuple: (subgraph, term_map, entities, concepts) where subgraph is
            a list of (s, p, o) string tuples, term_map maps each string
            tuple to its original (Term, Term, Term), entities is the seed
            entity list, and concepts is the extracted concept list.
        """
        entities, concepts = await self.get_entities(query)
@ -330,9 +361,9 @@ class Query:
        logger.debug("Getting subgraph...")

        # Use optimized batch traversal instead of sequential processing
        subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length)

        return list(subgraph), term_map, entities, concepts

    async def resolve_labels_batch(self, entities):
        """Resolve labels for multiple entities in parallel"""
@ -353,7 +384,7 @@ class Query:
        - entities: list of seed entity URI strings
        - concepts: list of concept strings extracted from query
        """
        subgraph, term_map, entities, concepts = await self.get_subgraph(query)

        # Filter out label triples
        filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
@ -377,7 +408,7 @@ class Query:
        # Apply labels to subgraph and build URI mapping
        labeled_edges = []
        uri_map = {}  # Maps edge_id of labeled edge -> original Term triple

        for s, p, o in filtered_subgraph:
            labeled_triple = (
@ -387,9 +418,9 @@ class Query:
            )
            labeled_edges.append(labeled_triple)

            # Map from labeled edge ID to original Terms (preserving types)
            labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
            uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o))

        labeled_edges = labeled_edges[0:self.max_subgraph_size]
@ -419,12 +450,14 @@ class Query:
        # Step 1: Find subgraphs containing these edges via tg:contains
        subgraph_tasks = []
        for s, p, o in edge_uris:
            # s, p, o may be Term objects (preserving types) or strings
            s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s)
            p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p)
            o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o)
            quoted = Term(
                type=TRIPLE,
                triple=SchemaTriple(
                    s=s_term, p=p_term, o=o_term,
                )
            )
            subgraph_tasks.append(