diff --git a/dev-tools/explainable-ai/README.md b/dev-tools/explainable-ai/README.md new file mode 100644 index 00000000..0eb7b21c --- /dev/null +++ b/dev-tools/explainable-ai/README.md @@ -0,0 +1,29 @@ +# Explainable AI Demo + +Demonstrates the TrustGraph streaming agent API with inline explainability +events. Sends an agent query, receives streaming thinking/observation/answer +chunks alongside RDF provenance events, then resolves the full provenance +chain from answer back to source documents. + +## What it shows + +- Streaming agent responses (thinking, observation, answer) +- Inline explainability events with RDF triples (W3C PROV + TrustGraph namespace) +- Label resolution for entity and predicate URIs +- Provenance chain traversal: subgraph → chunk → page → document +- Source text retrieval from the librarian using chunk IDs + +## Prerequisites + +A running TrustGraph instance with at least one loaded document and a +running flow. The default configuration connects to `ws://localhost:8088`. + +## Usage + +```bash +npm install +node index.js +``` + +Edit the `QUESTION` and `SOCKET_URL` constants at the top of `index.js` +to change the query or target instance. diff --git a/dev-tools/explainable-ai/index.js b/dev-tools/explainable-ai/index.js new file mode 100644 index 00000000..db0fc016 --- /dev/null +++ b/dev-tools/explainable-ai/index.js @@ -0,0 +1,552 @@ + +// ============================================================================ +// TrustGraph Explainability API Demo +// ============================================================================ +// +// This example demonstrates how to use the TrustGraph streaming agent API +// with explainability events. It shows how to: +// +// 1. Send an agent query and receive streaming thinking/observation/answer +// 2. Receive and parse explainability events as they arrive +// 3. Resolve the provenance chain for knowledge graph edges: +// subgraph -> chunk -> page -> document +// 4. 
Fetch source text from the librarian using chunk IDs +// +// Explainability events use RDF triples (W3C PROV ontology + TrustGraph +// namespace) to describe the retrieval pipeline. The key event types are: +// +// - AgentQuestion: The initial user query +// - Analysis/ToolUse: Agent deciding which tool to invoke +// - GraphRagQuestion: A sub-query sent to the Graph RAG pipeline +// - Grounding: Concepts extracted from the query for graph traversal +// - Exploration: Entities discovered during knowledge graph traversal +// - Focus: The selected knowledge graph edges (triples) used for context +// - Synthesis: The RAG answer synthesised from retrieved context +// - Observation: The tool result returned to the agent +// - Conclusion/Answer: The agent's final answer +// +// Each event carries RDF triples that link back through the provenance chain, +// allowing full traceability from answer back to source documents. +// ============================================================================ + +import { createTrustGraphSocket } from '@trustgraph/client'; + +// --------------------------------------------------------------------------- +// Configuration +// --------------------------------------------------------------------------- + +const USER = "trustgraph"; + +// Simple question +const QUESTION = "Tell me about the author of the document"; + +// Likely to trigger the deep research plan-and-execute pattern +//const QUESTION = "Do deep research and explain the risks posed by globalisation in the modern world"; + +const SOCKET_URL = "ws://localhost:8088/api/v1/socket"; + +// --------------------------------------------------------------------------- +// RDF predicates and TrustGraph namespace constants +// --------------------------------------------------------------------------- + +const RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"; +const RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"; +const PROV_DERIVED = 
"http://www.w3.org/ns/prov#wasDerivedFrom"; + +const TG_GROUNDING = "https://trustgraph.ai/ns/Grounding"; +const TG_CONCEPT = "https://trustgraph.ai/ns/concept"; +const TG_EXPLORATION = "https://trustgraph.ai/ns/Exploration"; +const TG_ENTITY = "https://trustgraph.ai/ns/entity"; +const TG_FOCUS = "https://trustgraph.ai/ns/Focus"; +const TG_EDGE = "https://trustgraph.ai/ns/edge"; +const TG_CONTAINS = "https://trustgraph.ai/ns/contains"; + +// --------------------------------------------------------------------------- +// Utility: check whether a set of triples assigns a given RDF type to an ID +// --------------------------------------------------------------------------- + +const isType = (triples, id, type) => + triples.some(t => t.s.i === id && t.p.i === RDF_TYPE && t.o.i === type); + +// --------------------------------------------------------------------------- +// Utility: word-wrap text for display +// --------------------------------------------------------------------------- + +const wrapText = (text, width, indent, maxLines) => { + const clean = text.replace(/\s+/g, " ").trim(); + const lines = []; + let remaining = clean; + while (remaining.length > 0 && lines.length < maxLines) { + if (remaining.length <= width) { + lines.push(remaining); + break; + } + let breakAt = remaining.lastIndexOf(" ", width); + if (breakAt <= 0) breakAt = width; + lines.push(remaining.substring(0, breakAt)); + remaining = remaining.substring(breakAt).trimStart(); + } + if (remaining.length > 0 && lines.length >= maxLines) + lines[lines.length - 1] += " ..."; + return lines.map(l => indent + l).join("\n"); +}; + +// --------------------------------------------------------------------------- +// Connect to TrustGraph +// --------------------------------------------------------------------------- + +console.log("=".repeat(80)); +console.log("TrustGraph Explainability API Demo"); +console.log("=".repeat(80)); +console.log(`Connecting to: ${SOCKET_URL}`); +console.log(`Question: 
${QUESTION}`); +console.log("=".repeat(80)); + +const client = createTrustGraphSocket(USER, undefined, SOCKET_URL); + +console.log("Connected, sending query...\n"); + +// Get a flow handle. Flows provide access to AI operations (agent, RAG, +// text completion, etc.) as well as knowledge graph queries. +const flow = client.flow("default"); + +// Get a librarian handle for fetching source document text. +const librarian = client.librarian(); + +// --------------------------------------------------------------------------- +// Inline explain event printing +// --------------------------------------------------------------------------- +// Explain events arrive during streaming alongside thinking/observation/ +// answer chunks. We print a summary immediately and store them for +// post-processing (label resolution and provenance lookups require async +// queries that can't run inside the synchronous callback). + +const explainEvents = []; + +const printExplainInline = (explainEvent) => { + const { explainId, explainTriples } = explainEvent; + if (!explainTriples) return; + + // Extract the RDF types assigned to the explain event's own ID. + // Every explain event has rdf:type triples that identify what kind + // of pipeline step it represents (Grounding, Exploration, Focus, etc.) + const types = explainTriples + .filter(t => t.s.i === explainId && t.p.i === RDF_TYPE) + .map(t => t.o.i); + + // Show short type names (e.g. "Grounding" instead of full URI) + const shortTypes = types + .map(t => t.split("/").pop().split("#").pop()) + .join(", "); + console.log(` [explain] ${shortTypes}`); + + // Grounding events contain the concepts extracted from the query. + // These are the seed terms used to begin knowledge graph traversal. 
+ if (isType(explainTriples, explainId, TG_GROUNDING)) { + const concepts = explainTriples + .filter(t => t.s.i === explainId && t.p.i === TG_CONCEPT) + .map(t => t.o.v); + console.log(` Grounding concepts: ${concepts.join(", ")}`); + } + + // Exploration events list the entities found during graph traversal. + // We show the count here; labelled names are printed after resolution. + if (isType(explainTriples, explainId, TG_EXPLORATION)) { + const count = explainTriples + .filter(t => t.s.i === explainId && t.p.i === TG_ENTITY).length; + console.log(` Entities: ${count} found (see below)`); + } +}; + +const collectExplain = (explainEvent) => { + printExplainInline(explainEvent); + explainEvents.push(explainEvent); +}; + +// --------------------------------------------------------------------------- +// Label resolution +// --------------------------------------------------------------------------- +// Many explain triples reference entities and predicates by URI. We query +// the knowledge graph for rdfs:label to get human-readable names. + +const resolveLabels = async (uris) => { + const labels = new Map(); + await Promise.all(uris.map(async (uri) => { + try { + const results = await flow.triplesQuery( + { t: "i", i: uri }, + { t: "i", i: RDFS_LABEL }, + ); + if (results.length > 0) { + labels.set(uri, results[0].o.v); + } + } catch (e) { + // No label found, fall back to URI + } + })); + return labels; +}; + +// --------------------------------------------------------------------------- +// Provenance resolution for knowledge graph edges +// --------------------------------------------------------------------------- +// Focus events contain the knowledge graph triples (edges) that were selected +// as context for the RAG answer. 
// Each edge can be traced back through the
// provenance chain to the original source document:
//
//   subgraph --contains--> <<s p o>>          (RDF-star triple term)
//   subgraph --wasDerivedFrom--> chunk        (text chunk)
//   chunk --wasDerivedFrom--> page            (document page)
//   page --wasDerivedFrom--> document         (original document)
//
// The chunk URI also serves as the content ID in the librarian, so it can
// be used to fetch the actual source text.

/**
 * Trace each knowledge graph edge back along its provenance chain.
 *
 * Lookups run in parallel and are best-effort: an edge with no containing
 * subgraph, or whose queries fail, is left out of the result (failures are
 * reported with console.warn). A chain that stops early yields a partial
 * entry holding only the hops that were found.
 *
 * @param {Array<{s: object, p: object, o: object}>} edgeTriples - Edge
 *   triples taken from Focus events; duplicates are resolved once.
 * @returns {Promise<Map<string, {subgraph: string, chunk?: string,
 *   page?: string, document?: string}>>} Map keyed by JSON.stringify(edge).
 */
const resolveEdgeSources = async (edgeTriples) => {
    const iri = (uri) => ({ t: "i", i: uri });

    // One hop along wasDerivedFrom: the first target IRI, or undefined when
    // the chain ends here.
    const derivedFrom = async (uri) => {
        const results = await flow.triplesQuery(iri(uri), iri(PROV_DERIVED));
        return results.length > 0 ? results[0].o.i : undefined;
    };

    // The result is keyed by the serialised edge, so an edge selected more
    // than once only needs a single set of queries.
    const uniqueEdges = new Map(
        edgeTriples.map((tr) => [JSON.stringify(tr), tr]),
    );
    const sources = new Map();

    await Promise.all([...uniqueEdges].map(async ([key, tr]) => {
        try {
            // Step 1: Find the subgraph that contains this edge triple.
            // The query uses an RDF-star triple term as the object: the
            // knowledge graph stores subgraph -> contains -> <<s p o>>.
            const subgraphResults = await flow.triplesQuery(
                undefined,
                iri(TG_CONTAINS),
                { t: "t", tr },
            );
            if (subgraphResults.length === 0) {
                if (tr.o.t === "l" || tr.o.t === "i") {
                    console.log(` No source match for triple:`);
                    console.log(` s: ${tr.s.i}`);
                    console.log(` p: ${tr.p.i}`);
                    console.log(` o: ${JSON.stringify(tr.o)}`);
                }
                return;
            }
            const subgraph = subgraphResults[0].s.i;

            // Step 2: subgraph -> chunk
            const chunk = await derivedFrom(subgraph);
            if (chunk === undefined) {
                sources.set(key, { subgraph });
                return;
            }

            // Step 3: chunk -> page
            const page = await derivedFrom(chunk);
            if (page === undefined) {
                sources.set(key, { subgraph, chunk });
                return;
            }

            // Step 4: page -> document (may be undefined)
            const document = await derivedFrom(page);
            sources.set(key, { subgraph, chunk, page, document });
        } catch (e) {
            // Best-effort: skip this edge, but say so.
            console.warn(`Provenance lookup failed for edge ${key}: ${e?.message ?? e}`);
        }
    }));

    return sources;
};

// ---------------------------------------------------------------------------
// Collect URIs that need label resolution
// ---------------------------------------------------------------------------
// Scans explain events for entity URIs (from Exploration events) and edge
// term URIs (from Focus events) so we can batch-resolve their labels.

/**
 * Gather every URI from the explain events that should be shown with a label.
 *
 * @param {Array<{explainId: string, explainTriples?: Array<object>}>} events
 *   Collected explain events; events without triples are skipped.
 * @returns {Set<string>} Entity URIs of Exploration events plus the IRI
 *   subject/predicate/object terms of Focus edge triples.
 */
const collectUris = (events) => {
    const uris = new Set();
    for (const { explainId, explainTriples } of events) {
        if (!explainTriples) continue;

        // Entity URIs from exploration
        if (isType(explainTriples, explainId, TG_EXPLORATION)) {
            for (const t of explainTriples) {
                const isEntity = t.s.i === explainId && t.p.i === TG_ENTITY;
                // Skip objects that carry no IRI rather than queueing a
                // lookup for `undefined`.
                if (isEntity && t.o.i !== undefined) uris.add(t.o.i);
            }
        }

        // Subject, predicate, and object URIs from focus edge triples
        if (isType(explainTriples, explainId, TG_FOCUS)) {
            for (const t of explainTriples) {
                if (t.p.i !== TG_EDGE || t.o.t !== "t") continue;
                for (const term of [t.o.tr.s, t.o.tr.p, t.o.tr.o]) {
                    if (term.t === "i") uris.add(term.i);
                }
            }
        }
    }
    return uris;
};

// ---------------------------------------------------------------------------
// Collect edge triples from Focus events
// ---------------------------------------------------------------------------
// Focus events contain selectedEdge -> edge relationships. Each edge's
// object is an RDF-star triple term ({t: "t", tr: {s, p, o}}) representing
// the actual knowledge graph triple used as RAG context.

/**
 * Extract the edge triples carried by one explain event.
 *
 * @param {string} explainId - The event's own ID.
 * @param {Array<object>|undefined} explainTriples - The event's triples.
 * @returns {Array<{s: object, p: object, o: object}>} The `tr` payload of
 *   every edge triple term, or [] when the event has no triples or is not a
 *   Focus event.
 */
const focusEdgesOf = (explainId, explainTriples) => {
    if (!explainTriples) return [];
    if (!isType(explainTriples, explainId, TG_FOCUS)) return [];
    return explainTriples
        .filter(t => t.p.i === TG_EDGE && t.o.t === "t")
        .map(t => t.o.tr);
};

/**
 * Collect the edge triples of all Focus events, in event order.
 *
 * @param {Array<{explainId: string, explainTriples?: Array<object>}>} events
 * @returns {Array<{s: object, p: object, o: object}>} Edge triples.
 */
const collectEdgeTriples = (events) =>
    events.flatMap(({ explainId, explainTriples }) =>
        focusEdgesOf(explainId, explainTriples));

// ---------------------------------------------------------------------------
// Print knowledge graph edges with provenance
// ---------------------------------------------------------------------------
// Displays each edge triple with resolved labels and its source location
// (chunk -> page -> document).

/**
 * Print the edges of every Focus event with labels and source locations.
 *
 * @param {Array<{explainId: string, explainTriples?: Array<object>}>} events
 * @param {Map<string, string>} labels - URI -> label; missing URIs are
 *   printed as-is.
 * @param {Map<string, {chunk?: string, page?: string, document?: string}>}
 *   edgeSources - Provenance keyed by JSON.stringify(edge).
 * @param {number} [maxEdges=20] - Edges printed per Focus event before the
 *   remainder is summarised as "... and N more".
 */
const printFocusEdges = (events, labels, edgeSources, maxEdges = 20) => {
    const label = (uri) => labels.get(uri) || uri;
    // IRIs are shown by label, literals by value, anything else as "?".
    const termValue = (term) =>
        term.t === "i" ? label(term.i) : (term.v || "?");

    for (const { explainId, explainTriples } of events) {
        const edges = focusEdgesOf(explainId, explainTriples);

        for (const tr of edges.slice(0, maxEdges)) {
            console.log(` ${termValue(tr.s)} -> ${termValue(tr.p)} -> ${termValue(tr.o)}`);
            const src = edgeSources.get(JSON.stringify(tr));
            if (!src) continue;
            const parts = [src.chunk, src.page, src.document]
                .filter(Boolean)
                .map(label);
            if (parts.length > 0)
                console.log(` Source: ${parts.join(" -> ")}`);
        }
        if (edges.length > maxEdges)
            console.log(` ... and ${edges.length - maxEdges} more`);
    }
};

// ---------------------------------------------------------------------------
// Fetch chunk text from the librarian
// ---------------------------------------------------------------------------
// The chunk URI (e.g.
// urn:chunk:UUID) serves as a universal ID that ties
// together provenance metadata, embeddings, and the source text content.
// The librarian stores the original text keyed by this same URI, so we
// can retrieve it with streamDocument(chunkUri).

/**
 * Read a whole chunk from the librarian by accumulating streamed pieces.
 *
 * Adapts the callback-style `librarian.streamDocument` to a promise. The
 * pieces are concatenated as delivered; this function does no decoding.
 *
 * @param {string} chunkUri - Chunk URI, used as the librarian document ID.
 * @returns {Promise<string>} Resolves with the concatenated content once the
 *   stream reports completion; rejects with the error passed to the error
 *   callback.
 */
const fetchChunkText = (chunkUri) => {
    return new Promise((resolve, reject) => {
        const pieces = [];
        librarian.streamDocument(
            chunkUri,
            (content, chunkIndex, totalChunks, complete) => {
                // A callback may carry no payload (e.g. a bare completion
                // signal); appending it would insert the text "undefined".
                if (content != null) pieces.push(content);
                if (complete) resolve(pieces.join(""));
            },
            (error) => reject(error),
        );
    });
};

// ===========================================================================
// Send the agent query
// ===========================================================================
// The agent callback receives four types of streaming content:
//   - think: the agent's reasoning (chain-of-thought)
//   - observe: tool results returned to the agent
//   - answer: the final answer being generated
//   - error: any errors during processing
//
// The onExplain callback fires for each explainability event, delivering
// RDF triples that describe what happened at each pipeline stage.

// Per-stream accumulators: each buffer collects chunks until the matching
// callback reports `complete`, then is printed and reset.
let thought = "";
let obs = "";
let ans = "";

// Think callback: agent reasoning / chain-of-thought
const onThink = (chunk, complete, messageId, metadata) => {
    thought += chunk;
    if (!complete) return;
    console.log("\nThinking:", thought, "\n");
    thought = "";
};

// Observe callback: tool results returned to the agent
const onObserve = (chunk, complete, messageId, metadata) => {
    obs += chunk;
    if (!complete) return;
    console.log("\nObservation:", obs, "\n");
    obs = "";
};

// Answer callback: the agent's final response
const onAnswer = (chunk, complete, messageId, metadata) => {
    ans += chunk;
    if (!complete) return;
    console.log("\nAnswer:", ans, "\n");
    ans = "";
};

// Error callback: report the error as pretty-printed JSON
const onError = (error) => {
    console.log(JSON.stringify({ type: "error", error }, null, 2));
};

// Explain callback: explainability events with RDF triples
const onExplain = (explainEvent) => {
    collectExplain(explainEvent);
};

await flow.agent(QUESTION, onThink, onObserve, onAnswer, onError, onExplain);

// ===========================================================================
// Post-processing: resolve labels, provenance, and source text
// ===========================================================================
// After the agent query completes, we have all the explain events. Now we
// can make async queries to:
//   1. Trace each edge back to its source document (provenance chain)
//   2. Resolve URIs to human-readable labels
//   3.
// Fetch the original text for each source chunk

console.log("Resolving provenance...\n");

// Resolve the provenance chain for each knowledge graph edge
const edgeTriples = collectEdgeTriples(explainEvents);
const edgeSources = await resolveEdgeSources(edgeTriples);

// Collect all URIs that need labels: entities, edge terms, and source URIs
// (the chunk/page/document URIs are only known once provenance is resolved,
// which is why this runs after resolveEdgeSources).
const uris = collectUris(explainEvents);
for (const src of edgeSources.values()) {
    if (src.chunk) uris.add(src.chunk);
    if (src.page) uris.add(src.page);
    if (src.document) uris.add(src.document);
}
const labels = await resolveLabels([...uris]);

// Label if one was resolved, otherwise the raw URI.
const label = (uri) => labels.get(uri) || uri;

// ---------------------------------------------------------------------------
// Display: Entities retrieved during graph exploration
// ---------------------------------------------------------------------------

// One "Entities Retrieved" section is printed per Exploration event; at most
// the first 10 entity names are listed.
for (const { explainId, explainTriples } of explainEvents) {
    if (!explainTriples) continue;
    if (!isType(explainTriples, explainId, TG_EXPLORATION)) continue;
    const entities = explainTriples
        .filter(t => t.s.i === explainId && t.p.i === TG_ENTITY)
        .map(t => label(t.o.i));
    const display = entities.slice(0, 10);
    console.log("=".repeat(80));
    console.log("Entities Retrieved");
    console.log("=".repeat(80));
    console.log(` ${entities.length} entities: ${display.join(", ")}${entities.length > 10 ? ", ..." : ""}`);
}

// ---------------------------------------------------------------------------
// Display: Knowledge graph edges with provenance
// ---------------------------------------------------------------------------

console.log("\n" + "=".repeat(80));
console.log("Knowledge Graph Edges");
console.log("=".repeat(80));
printFocusEdges(explainEvents, labels, edgeSources);

// ---------------------------------------------------------------------------
// Display: Source text for each chunk referenced by the edges
// ---------------------------------------------------------------------------

// Several edges can share a chunk; fetch and show each chunk only once.
const uniqueChunks = new Set();
for (const src of edgeSources.values()) {
    if (src.chunk) uniqueChunks.add(src.chunk);
}

console.log(`\nFetching text for ${uniqueChunks.size} source chunks...`);
const chunkTexts = new Map();
await Promise.all([...uniqueChunks].map(async (chunkUri) => {
    try {
        // streamDocument returns base64-encoded content
        const text = await fetchChunkText(chunkUri);
        chunkTexts.set(chunkUri, text);
    } catch (e) {
        // Failed to fetch text for this chunk. Best-effort: the chunk is
        // still listed under "Sources" below, just without a text snippet.
    }
}));

console.log("\n" + "=".repeat(80));
console.log("Sources");
console.log("=".repeat(80));

let sourceIndex = 0;
for (const chunkUri of uniqueChunks) {
    sourceIndex++;
    const chunkLabel = labels.get(chunkUri) || chunkUri;

    // Find the page and document labels for this chunk, taken from the
    // first edge whose provenance points at it.
    let pageLabel, docLabel;
    for (const src of edgeSources.values()) {
        if (src.chunk === chunkUri) {
            if (src.page) pageLabel = labels.get(src.page) || src.page;
            if (src.document) docLabel = labels.get(src.document) || src.document;
            break;
        }
    }

    console.log(`\n [${sourceIndex}] ${docLabel || "?"} / ${pageLabel || "?"} / ${chunkLabel}`);
    console.log(" " + "-".repeat(70));

    // Decode the base64 content and display a wrapped snippet
    const b64 = chunkTexts.get(chunkUri);
    if (b64) {
        const text = Buffer.from(b64, "base64").toString("utf-8");
        console.log(wrapText(text,
76, " ", 6)); + } +} + +// --------------------------------------------------------------------------- +// Clean up +// --------------------------------------------------------------------------- + +console.log("\n" + "=".repeat(80)); +console.log("Query complete"); +console.log("=".repeat(80)); + +client.close(); +process.exit(0); diff --git a/dev-tools/explainable-ai/package.json b/dev-tools/explainable-ai/package.json new file mode 100644 index 00000000..cd96584e --- /dev/null +++ b/dev-tools/explainable-ai/package.json @@ -0,0 +1,13 @@ +{ + "name": "explain-api-example", + "version": "1.0.0", + "description": "TrustGraph explainability API example", + "main": "index.js", + "type": "module", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "dependencies": { + "@trustgraph/client": "^1.7.2" + } +} diff --git a/tests/unit/test_agent/test_orchestrator_provenance_integration.py b/tests/unit/test_agent/test_orchestrator_provenance_integration.py new file mode 100644 index 00000000..96d41259 --- /dev/null +++ b/tests/unit/test_agent/test_orchestrator_provenance_integration.py @@ -0,0 +1,655 @@ +""" +Integration tests for agent-orchestrator provenance chains. + +Tests all three patterns by calling iterate() with mocked dependencies +and verifying the explain events emitted via respond(). 
+ +Provenance chains: + React: session → iteration → (observation or final) + Plan: session → plan → step-result(s) → synthesis + Supervisor: session → decomposition → finding(s) → synthesis +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from dataclasses import dataclass, field + +from trustgraph.schema import ( + AgentRequest, AgentResponse, AgentStep, PlanStep, +) + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, + GRAPH_RETRIEVAL, +) + +# Agent provenance type constants +from trustgraph.provenance.namespaces import ( + TG_AGENT_QUESTION, + TG_ANALYSIS, + TG_TOOL_USE, + TG_OBSERVATION_TYPE, + TG_CONCLUSION, + TG_DECOMPOSITION, + TG_FINDING, + TG_PLAN_TYPE, + TG_STEP_RESULT, + TG_SYNTHESIS as TG_AGENT_SYNTHESIS, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def has_type(triples, subject, rdf_type): + return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in triples + ) + + +def derived_from(triples, subject): + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return t.o.iri if t else None + + +def collect_explain_events(respond_mock): + """Extract explain events from a respond mock's call history.""" + events = [] + for call in respond_mock.call_args_list: + resp = call[0][0] + if isinstance(resp, AgentResponse) and resp.chunk_type == "explain": + events.append({ + "explain_id": resp.explain_id, + "explain_graph": resp.explain_graph, + "triples": resp.explain_triples, + }) + return events + + +# --------------------------------------------------------------------------- +# Mock processor +# 
--------------------------------------------------------------------------- + +def make_mock_processor(tools=None): + """Build a mock processor with the minimal interface patterns need.""" + processor = MagicMock() + processor.max_iterations = 10 + processor.save_answer_content = AsyncMock() + + # provenance_session_uri must return a real URI + def mock_session_uri(session_id): + return f"urn:trustgraph:agent:session:{session_id}" + processor.provenance_session_uri.side_effect = mock_session_uri + + # Agent with tools + agent = MagicMock() + agent.tools = tools or {} + agent.additional_context = "" + processor.agent = agent + + # Aggregator for supervisor + processor.aggregator = MagicMock() + + return processor + + +def make_mock_flow(): + """Build a mock flow that returns async mock producers.""" + producers = {} + + def flow_factory(name): + if name not in producers: + producers[name] = AsyncMock() + return producers[name] + + flow = MagicMock(side_effect=flow_factory) + flow._producers = producers + return flow + + +def make_base_request(**kwargs): + """Build a minimal AgentRequest.""" + defaults = dict( + question="What is quantum computing?", + state="", + group=[], + history=[], + user="testuser", + collection="default", + streaming=False, + session_id="test-session-123", + conversation_id="", + pattern="react", + task_type="", + framing="", + correlation_id="", + parent_session_id="", + subagent_goal="", + expected_siblings=0, + ) + defaults.update(kwargs) + return AgentRequest(**defaults) + + +# --------------------------------------------------------------------------- +# React pattern tests +# --------------------------------------------------------------------------- + +class TestReactPatternProvenance: + """ + React pattern chain: session → iteration → final + (single iteration ending in Final answer) + """ + + @pytest.mark.asyncio + async def test_single_iteration_final_answer(self): + """ + A single react iteration that produces a Final answer should 
emit:
        session, iteration, final — in that order.
        """
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action, Final

        processor = make_mock_processor()
        pattern = ReactPattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        request = make_base_request()

        # Mock AgentManager.react to call on_action then return Final
        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            final = Final(
                thought="I know the answer",
                final="Quantum computing uses qubits.",
            )

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                # Simulate the on_action callback before returning Final
                if on_action:
                    await on_action(Action(
                        thought="I know the answer",
                        name="final",
                        arguments={},
                        observation="",
                    ))
                return final

            mock_am.react.side_effect = mock_react

            # Must run while AgentManager is still patched.
            await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 3 events: session, iteration, final
        assert len(events) == 3, (
            f"Expected 3 explain events (session, iteration, final), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        # Check types
        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
        assert has_type(events[2]["triples"], events[2]["explain_id"], TG_CONCLUSION)

        # Check derivation chain: pool the triples of all events so a
        # derivation link can be found whichever event carried it.
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])

        uris = [e["explain_id"] for e in events]

        # iteration derives from session
        assert derived_from(all_triples, uris[1]) == uris[0]
        # final derives from session (first iteration, no prior observation)
        assert derived_from(all_triples, uris[2]) == uris[0]

    @pytest.mark.asyncio
    async def test_iteration_with_tool_call(self):
        """
        A react iteration that calls a tool (not Final) should emit:
        session, iteration, observation — then call next() for continuation.
        """
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action

        # Create a mock tool
        mock_tool = MagicMock()
        mock_tool.name = "knowledge-query"
        mock_tool.description = "Query the knowledge base"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="The answer is 42")
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = make_mock_processor(
            tools={"knowledge-query": mock_tool}
        )
        pattern = ReactPattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        request = make_base_request()

        # The mocked react() both reports this action via on_action and
        # returns it, i.e. the iteration ends in a tool call, not a Final.
        action = Action(
            thought="I need to look this up",
            name="knowledge-query",
            arguments={"question": "What is quantum computing?"},
            observation="Quantum computing uses qubits.",
        )

        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                if on_action:
                    await on_action(action)
                return action

            mock_am.react.side_effect = mock_react

            await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 3 events: session, iteration, observation
        assert len(events) == 3, (
            f"Expected 3 explain events (session, iteration, observation), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
        assert has_type(events[2]["triples"], events[2]["explain_id"], TG_OBSERVATION_TYPE)

        #
# next() should have been called to continue the loop
        assert next_fn.called

    @pytest.mark.asyncio
    async def test_all_triples_in_retrieval_graph(self):
        """All explain triples should be in urn:graph:retrieval."""
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action, Final

        processor = make_mock_processor()
        pattern = ReactPattern(processor)
        respond = AsyncMock()
        flow = make_mock_flow()

        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            # Minimal single-iteration run: report a "final" action, then
            # return a Final answer.
            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                if on_action:
                    await on_action(Action(
                        thought="done", name="final",
                        arguments={}, observation="",
                    ))
                return Final(thought="done", final="answer")

            mock_am.react.side_effect = mock_react
            await pattern.iterate(
                make_base_request(), respond, AsyncMock(), flow,
            )

        # Every triple of every explain event must carry the retrieval graph.
        for event in collect_explain_events(respond):
            for t in event["triples"]:
                assert t.g == GRAPH_RETRIEVAL


# ---------------------------------------------------------------------------
# Plan-then-execute pattern tests
# ---------------------------------------------------------------------------

class TestPlanPatternProvenance:
    """
    Plan pattern chain:
      Planning iteration: session → plan
      Execution iterations: step-result(s) → synthesis
    """

    @pytest.mark.asyncio
    async def test_planning_iteration_emits_session_and_plan(self):
        """
        The first iteration (planning) should emit:
        session, plan — then call next() with the plan in history.
+ """ + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + + processor = make_mock_processor() + pattern = PlanThenExecutePattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + # Mock prompt client for plan creation + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = [ + {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []}, + {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]}, + ] + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + request = make_base_request(pattern="plan") + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have 2 events: session, plan + assert len(events) == 2, ( + f"Expected 2 explain events (session, plan), " + f"got {len(events)}: {[e['explain_id'] for e in events]}" + ) + + assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION) + assert has_type(events[1]["triples"], events[1]["explain_id"], TG_PLAN_TYPE) + + # Plan should derive from session + all_triples = [] + for e in events: + all_triples.extend(e["triples"]) + assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"] + + # next() should have been called with plan in history + assert next_fn.called + + @pytest.mark.asyncio + async def test_execution_iteration_emits_step_result(self): + """ + An execution iteration should emit a step-result event. 
+ """ + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + + # Create a mock tool + mock_tool = MagicMock() + mock_tool.name = "knowledge-query" + mock_tool.description = "Query KB" + mock_tool.arguments = [] + mock_tool.groups = [] + mock_tool.states = {} + mock_tool_impl = AsyncMock(return_value="Found the answer") + mock_tool.implementation = MagicMock(return_value=mock_tool_impl) + + processor = make_mock_processor( + tools={"knowledge-query": mock_tool} + ) + pattern = PlanThenExecutePattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + # Mock prompt for step execution + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = { + "tool": "knowledge-query", + "arguments": {"question": "quantum computing"}, + } + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + # Request with plan already in history (second iteration) + plan_step = AgentStep( + thought="Created plan", + action="plan", + arguments={}, + observation="[]", + step_type="plan", + plan=[ + PlanStep(goal="Find info", tool_hint="knowledge-query", + depends_on=[], status="pending", result=""), + ], + ) + request = make_base_request( + pattern="plan", + history=[plan_step], + ) + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have step-result (no session on iteration > 1) + step_events = [ + e for e in events + if has_type(e["triples"], e["explain_id"], TG_STEP_RESULT) + ] + assert len(step_events) == 1, ( + f"Expected 1 step-result event, got {len(step_events)}" + ) + + @pytest.mark.asyncio + async def test_synthesis_after_all_steps_complete(self): + """ + When all plan steps are completed, synthesis should be emitted. 
+ """ + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + + processor = make_mock_processor() + pattern = PlanThenExecutePattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + # Mock prompt for synthesis + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = "The synthesised answer." + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + # Request with all steps completed + exec_step = AgentStep( + thought="Executing step", + action="knowledge-query", + arguments={}, + observation="Result", + step_type="execute", + plan=[ + PlanStep(goal="Find info", tool_hint="knowledge-query", + depends_on=[], status="completed", result="Found it"), + ], + ) + request = make_base_request( + pattern="plan", + history=[exec_step], + ) + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have synthesis event + synth_events = [ + e for e in events + if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS) + ] + assert len(synth_events) == 1, ( + f"Expected 1 synthesis event, got {len(synth_events)}" + ) + + +# --------------------------------------------------------------------------- +# Supervisor pattern tests +# --------------------------------------------------------------------------- + +class TestSupervisorPatternProvenance: + """ + Supervisor pattern chain: + Decompose: session → decomposition + (Fan-out to subagents happens externally) + Synthesise: synthesis (derives from findings) + """ + + @pytest.mark.asyncio + async def test_decompose_emits_session_and_decomposition(self): + """ + The decompose phase should emit: session, decomposition. 
+ """ + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + processor = make_mock_processor() + pattern = SupervisorPattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + # Mock prompt for decomposition + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = [ + "What is quantum computing?", + "What are qubits?", + ] + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + request = make_base_request(pattern="supervisor") + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have 2 events: session, decomposition + assert len(events) == 2, ( + f"Expected 2 explain events (session, decomposition), " + f"got {len(events)}: {[e['explain_id'] for e in events]}" + ) + + assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION) + assert has_type(events[1]["triples"], events[1]["explain_id"], TG_DECOMPOSITION) + + # Decomposition derives from session + all_triples = [] + for e in events: + all_triples.extend(e["triples"]) + assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"] + + @pytest.mark.asyncio + async def test_synthesis_emits_after_subagent_results(self): + """ + When subagent results arrive, synthesis should be emitted. + """ + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + processor = make_mock_processor() + pattern = SupervisorPattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + # Mock prompt for synthesis + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = "The combined answer." 
+ + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + # Request with subagent results in history + synth_step = AgentStep( + thought="", + action="synthesise", + arguments={}, + observation="", + step_type="synthesise", + subagent_results={ + "What is quantum computing?": "It uses qubits", + "What are qubits?": "Quantum bits", + }, + ) + request = make_base_request( + pattern="supervisor", + history=[synth_step], + ) + + await pattern.iterate(request, respond, next_fn, flow) + + events = collect_explain_events(respond) + + # Should have synthesis event (no session on iteration > 1) + synth_events = [ + e for e in events + if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS) + ] + assert len(synth_events) == 1 + + @pytest.mark.asyncio + async def test_decompose_fans_out_subagents(self): + """The decompose phase should call next() for each subagent goal.""" + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + processor = make_mock_processor() + pattern = SupervisorPattern(processor) + + respond = AsyncMock() + next_fn = AsyncMock() + flow = make_mock_flow() + + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"] + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + request = make_base_request(pattern="supervisor") + + await pattern.iterate(request, respond, next_fn, flow) + + # 3 subagent requests fanned out + assert next_fn.call_count == 3 diff --git a/tests/unit/test_provenance/test_graph_rag_chain.py b/tests/unit/test_provenance/test_graph_rag_chain.py new file mode 100644 index 00000000..657384b0 --- /dev/null +++ b/tests/unit/test_provenance/test_graph_rag_chain.py @@ -0,0 +1,295 @@ +""" +Structural test for the graph-rag provenance chain. 
+ +Verifies that a complete graph-rag query produces the expected +provenance chain: + + question → grounding → exploration → focus → synthesis + +Each step must: +- Have the correct rdf:type +- Link to its predecessor via prov:wasDerivedFrom +- Carry expected domain-specific data +""" + +import pytest + +from trustgraph.provenance.triples import ( + question_triples, + grounding_triples, + exploration_triples, + focus_triples, + synthesis_triples, +) +from trustgraph.provenance.uris import ( + question_uri, + grounding_uri, + exploration_uri, + focus_uri, + synthesis_uri, +) +from trustgraph.provenance.namespaces import ( + RDF_TYPE, RDFS_LABEL, + PROV_ENTITY, PROV_WAS_DERIVED_FROM, + TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_GRAPH_RAG_QUESTION, TG_ANSWER_TYPE, + TG_QUERY, TG_CONCEPT, TG_ENTITY, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_DOCUMENT, + PROV_STARTED_AT_TIME, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +SESSION_ID = "test-session-1234" + + +def find_triple(triples, predicate, subject=None): + """Find first triple matching predicate (and optionally subject).""" + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + """Find all triples matching predicate (and optionally subject).""" + return [ + t for t in triples + if t.p.iri == predicate + and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + """Check if subject has the given rdf:type.""" + return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in triples + ) + + +def derived_from(triples, subject): + """Get the wasDerivedFrom target URI for a subject.""" + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return 
t.o.iri if t else None + + +# --------------------------------------------------------------------------- +# Build the full chain +# --------------------------------------------------------------------------- + +@pytest.fixture +def chain(): + """Build all provenance triples for a complete graph-rag query.""" + q_uri = question_uri(SESSION_ID) + gnd_uri = grounding_uri(SESSION_ID) + exp_uri = exploration_uri(SESSION_ID) + foc_uri = focus_uri(SESSION_ID) + syn_uri = synthesis_uri(SESSION_ID) + + q = question_triples(q_uri, "What is quantum computing?", "2026-01-01T00:00:00Z") + gnd = grounding_triples(gnd_uri, q_uri, ["quantum", "computing"]) + exp = exploration_triples( + exp_uri, gnd_uri, edge_count=42, + entities=["urn:entity:1", "urn:entity:2"], + ) + foc = focus_triples( + foc_uri, exp_uri, + selected_edges_with_reasoning=[ + { + "edge": ( + "http://example.com/QuantumComputing", + "http://schema.org/relatedTo", + "http://example.com/Physics", + ), + "reasoning": "Directly relevant to the query", + }, + { + "edge": ( + "http://example.com/QuantumComputing", + "http://schema.org/name", + "Quantum Computing", + ), + "reasoning": "Provides the entity label", + }, + ], + session_id=SESSION_ID, + ) + syn = synthesis_triples(syn_uri, foc_uri, document_id="urn:doc:answer-1") + + return { + "uris": { + "question": q_uri, + "grounding": gnd_uri, + "exploration": exp_uri, + "focus": foc_uri, + "synthesis": syn_uri, + }, + "triples": { + "question": q, + "grounding": gnd, + "exploration": exp, + "focus": foc, + "synthesis": syn, + }, + "all": q + gnd + exp + foc + syn, + } + + +# --------------------------------------------------------------------------- +# Chain structure tests +# --------------------------------------------------------------------------- + +class TestGraphRagProvenanceChain: + """Verify the full question → grounding → exploration → focus → synthesis chain.""" + + def test_chain_has_five_stages(self, chain): + """Each stage should produce at least some 
triples.""" + for stage in ["question", "grounding", "exploration", "focus", "synthesis"]: + assert len(chain["triples"][stage]) > 0, f"{stage} produced no triples" + + def test_derivation_chain(self, chain): + """ + The wasDerivedFrom links must form: + grounding → question, exploration → grounding, + focus → exploration, synthesis → focus. + """ + uris = chain["uris"] + all_triples = chain["all"] + + assert derived_from(all_triples, uris["grounding"]) == uris["question"] + assert derived_from(all_triples, uris["exploration"]) == uris["grounding"] + assert derived_from(all_triples, uris["focus"]) == uris["exploration"] + assert derived_from(all_triples, uris["synthesis"]) == uris["focus"] + + def test_question_has_no_parent(self, chain): + """The root question should not derive from anything (no parent_uri).""" + uris = chain["uris"] + all_triples = chain["all"] + assert derived_from(all_triples, uris["question"]) is None + + def test_question_with_parent(self): + """When a parent_uri is given, question should derive from it.""" + q_uri = question_uri("child-session") + parent = "urn:trustgraph:agent:iteration:parent" + q = question_triples(q_uri, "sub-query", "2026-01-01T00:00:00Z", + parent_uri=parent) + assert derived_from(q, q_uri) == parent + + +# --------------------------------------------------------------------------- +# Type annotation tests +# --------------------------------------------------------------------------- + +class TestGraphRagProvenanceTypes: + """Each stage must have the correct rdf:type annotations.""" + + def test_question_types(self, chain): + uris = chain["uris"] + triples = chain["triples"]["question"] + assert has_type(triples, uris["question"], PROV_ENTITY) + assert has_type(triples, uris["question"], TG_GRAPH_RAG_QUESTION) + + def test_grounding_types(self, chain): + uris = chain["uris"] + triples = chain["triples"]["grounding"] + assert has_type(triples, uris["grounding"], PROV_ENTITY) + assert has_type(triples, uris["grounding"], 
TG_GROUNDING) + + def test_exploration_types(self, chain): + uris = chain["uris"] + triples = chain["triples"]["exploration"] + assert has_type(triples, uris["exploration"], PROV_ENTITY) + assert has_type(triples, uris["exploration"], TG_EXPLORATION) + + def test_focus_types(self, chain): + uris = chain["uris"] + triples = chain["triples"]["focus"] + assert has_type(triples, uris["focus"], PROV_ENTITY) + assert has_type(triples, uris["focus"], TG_FOCUS) + + def test_synthesis_types(self, chain): + uris = chain["uris"] + triples = chain["triples"]["synthesis"] + assert has_type(triples, uris["synthesis"], PROV_ENTITY) + assert has_type(triples, uris["synthesis"], TG_SYNTHESIS) + assert has_type(triples, uris["synthesis"], TG_ANSWER_TYPE) + + +# --------------------------------------------------------------------------- +# Domain-specific content tests +# --------------------------------------------------------------------------- + +class TestGraphRagProvenanceContent: + """Each stage should carry the expected domain data.""" + + def test_question_has_query_text(self, chain): + uris = chain["uris"] + t = find_triple(chain["triples"]["question"], TG_QUERY, uris["question"]) + assert t is not None + assert t.o.value == "What is quantum computing?" 
+ + def test_question_has_timestamp(self, chain): + uris = chain["uris"] + t = find_triple(chain["triples"]["question"], PROV_STARTED_AT_TIME, uris["question"]) + assert t is not None + assert t.o.value == "2026-01-01T00:00:00Z" + + def test_grounding_has_concepts(self, chain): + uris = chain["uris"] + concepts = find_triples(chain["triples"]["grounding"], TG_CONCEPT, uris["grounding"]) + concept_values = {t.o.value for t in concepts} + assert concept_values == {"quantum", "computing"} + + def test_exploration_has_edge_count(self, chain): + uris = chain["uris"] + t = find_triple(chain["triples"]["exploration"], TG_EDGE_COUNT, uris["exploration"]) + assert t is not None + assert t.o.value == "42" + + def test_exploration_has_entities(self, chain): + uris = chain["uris"] + entities = find_triples(chain["triples"]["exploration"], TG_ENTITY, uris["exploration"]) + entity_iris = {t.o.iri for t in entities} + assert entity_iris == {"urn:entity:1", "urn:entity:2"} + + def test_focus_has_selected_edges(self, chain): + uris = chain["uris"] + edges = find_triples(chain["triples"]["focus"], TG_SELECTED_EDGE, uris["focus"]) + assert len(edges) == 2 + + def test_focus_edges_have_quoted_triples(self, chain): + """Each edge selection entity should have a tg:edge with a quoted triple.""" + focus = chain["triples"]["focus"] + edge_triples = find_triples(focus, TG_EDGE) + assert len(edge_triples) == 2 + + # Each should have a quoted triple as the object + for t in edge_triples: + assert t.o.triple is not None, "tg:edge object should be a quoted triple" + + def test_focus_edges_have_reasoning(self, chain): + """Each edge selection entity should have tg:reasoning.""" + focus = chain["triples"]["focus"] + reasoning = find_triples(focus, TG_REASONING) + assert len(reasoning) == 2 + reasoning_texts = {t.o.value for t in reasoning} + assert "Directly relevant to the query" in reasoning_texts + assert "Provides the entity label" in reasoning_texts + + def 
test_synthesis_has_document_ref(self, chain): + uris = chain["uris"] + t = find_triple(chain["triples"]["synthesis"], TG_DOCUMENT, uris["synthesis"]) + assert t is not None + assert t.o.iri == "urn:doc:answer-1" + + def test_synthesis_has_labels(self, chain): + uris = chain["uris"] + t = find_triple(chain["triples"]["synthesis"], RDFS_LABEL, uris["synthesis"]) + assert t is not None + assert t.o.value == "Synthesis" diff --git a/tests/unit/test_retrieval/test_document_rag_provenance_integration.py b/tests/unit/test_retrieval/test_document_rag_provenance_integration.py new file mode 100644 index 00000000..74157285 --- /dev/null +++ b/tests/unit/test_retrieval/test_document_rag_provenance_integration.py @@ -0,0 +1,380 @@ +""" +Integration test: run a full DocumentRag.query() with mocked subsidiary +clients and verify the explain_callback receives the complete provenance +chain in the correct order with correct structure. + +Document-RAG provenance chain (4 stages): + question → grounding → exploration → synthesis +""" + +import pytest +from unittest.mock import AsyncMock +from dataclasses import dataclass + +from trustgraph.retrieval.document_rag.document_rag import DocumentRag + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, + TG_SYNTHESIS, TG_ANSWER_TYPE, + TG_QUERY, TG_CONCEPT, + TG_CHUNK_COUNT, TG_SELECTED_CHUNK, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + return [ + t for t in triples + if t.p.iri == predicate + and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + 
return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in triples + ) + + +def derived_from(triples, subject): + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return t.o.iri if t else None + + +@dataclass +class ChunkMatch: + """Mimics the result from doc_embeddings_client.query().""" + chunk_id: str + + +# --------------------------------------------------------------------------- +# Mock setup +# --------------------------------------------------------------------------- + +CHUNK_A = "urn:chunk:policy-doc-1:chunk-0" +CHUNK_B = "urn:chunk:policy-doc-1:chunk-1" +CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase." +CHUNK_B_CONTENT = "Refunds are processed to the original payment method." + + +def build_mock_clients(): + """ + Build mock clients for a document-rag query. + + Client call sequence during query(): + 1. prompt_client.prompt("extract-concepts", ...) -> concepts + 2. embeddings_client.embed(concepts) -> vectors + 3. doc_embeddings_client.query(vector, ...) -> chunk matches + 4. fetch_chunk(chunk_id, user) -> chunk content + 5. prompt_client.document_prompt(query, documents) -> answer + """ + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + doc_embeddings_client = AsyncMock() + fetch_chunk = AsyncMock() + + # 1. Concept extraction + async def mock_prompt(template_id, variables=None, **kwargs): + if template_id == "extract-concepts": + return "return policy\nrefund" + return "" + + prompt_client.prompt.side_effect = mock_prompt + + # 2. Embedding vectors + embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + + # 3. Chunk matching + doc_embeddings_client.query.return_value = [ + ChunkMatch(chunk_id=CHUNK_A), + ChunkMatch(chunk_id=CHUNK_B), + ] + + # 4. Chunk content + async def mock_fetch(chunk_id, user): + return { + CHUNK_A: CHUNK_A_CONTENT, + CHUNK_B: CHUNK_B_CONTENT, + }[chunk_id] + + fetch_chunk.side_effect = mock_fetch + + # 5. 
Synthesis + prompt_client.document_prompt.return_value = ( + "Items can be returned within 30 days for a full refund." + ) + + return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestDocumentRagQueryProvenance: + """ + Run a real DocumentRag.query() and verify the provenance chain emitted + via explain_callback. + """ + + @pytest.mark.asyncio + async def test_explain_callback_receives_four_events(self): + """query() should emit exactly 4 explain events.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + assert len(events) == 4, ( + f"Expected 4 explain events (question, grounding, exploration, " + f"synthesis), got {len(events)}" + ) + + @pytest.mark.asyncio + async def test_events_have_correct_types_in_order(self): + """ + Events should arrive as: + question, grounding, exploration, synthesis. 
+ """ + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + expected_types = [ + TG_DOC_RAG_QUESTION, + TG_GROUNDING, + TG_EXPLORATION, + TG_SYNTHESIS, + ] + + for i, expected_type in enumerate(expected_types): + uri = events[i]["explain_id"] + triples = events[i]["triples"] + assert has_type(triples, uri, expected_type), ( + f"Event {i} (uri={uri}) should have type {expected_type}" + ) + + @pytest.mark.asyncio + async def test_derivation_chain_links_correctly(self): + """ + Each event's URI should link to the previous via wasDerivedFrom: + question → (none) + grounding → question + exploration → grounding + synthesis → exploration + """ + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + uris = [e["explain_id"] for e in events] + all_triples = [] + for e in events: + all_triples.extend(e["triples"]) + + # question has no parent + assert derived_from(all_triples, uris[0]) is None + + # grounding → question + assert derived_from(all_triples, uris[1]) == uris[0] + + # exploration → grounding + assert derived_from(all_triples, uris[2]) == uris[1] + + # synthesis → exploration + assert derived_from(all_triples, uris[3]) == uris[2] + + @pytest.mark.asyncio + async def test_question_carries_query_text(self): + """The question event should contain the original query string.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await 
rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + q_uri = events[0]["explain_id"] + q_triples = events[0]["triples"] + t = find_triple(q_triples, TG_QUERY, q_uri) + assert t is not None + assert t.o.value == "What is the return policy?" + + @pytest.mark.asyncio + async def test_grounding_carries_concepts(self): + """The grounding event should list extracted concepts.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + gnd_uri = events[1]["explain_id"] + gnd_triples = events[1]["triples"] + concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri) + concept_values = {t.o.value for t in concepts} + assert "return policy" in concept_values + assert "refund" in concept_values + + @pytest.mark.asyncio + async def test_exploration_has_chunk_count(self): + """The exploration event should report the number of chunks retrieved.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + exp_uri = events[2]["explain_id"] + exp_triples = events[2]["triples"] + t = find_triple(exp_triples, TG_CHUNK_COUNT, exp_uri) + assert t is not None + assert int(t.o.value) == 2 + + @pytest.mark.asyncio + async def test_exploration_has_selected_chunks(self): + """The exploration event should list the chunk IDs that were fetched.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is 
the return policy?", + explain_callback=explain_callback, + ) + + exp_uri = events[2]["explain_id"] + exp_triples = events[2]["triples"] + chunks = find_triples(exp_triples, TG_SELECTED_CHUNK, exp_uri) + chunk_iris = {t.o.iri for t in chunks} + assert CHUNK_A in chunk_iris + assert CHUNK_B in chunk_iris + + @pytest.mark.asyncio + async def test_synthesis_is_answer_type(self): + """The synthesis event should have tg:Synthesis and tg:Answer types.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + syn_uri = events[3]["explain_id"] + syn_triples = events[3]["triples"] + assert has_type(syn_triples, syn_uri, TG_SYNTHESIS) + assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE) + + @pytest.mark.asyncio + async def test_query_returns_answer_text(self): + """query() should return the synthesised answer.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + result = await rag.query( + query="What is the return policy?", + explain_callback=AsyncMock(), + ) + + assert result == "Items can be returned within 30 days for a full refund." + + @pytest.mark.asyncio + async def test_no_explain_callback_still_works(self): + """query() without explain_callback should return answer normally.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + result = await rag.query(query="What is the return policy?") + assert result == "Items can be returned within 30 days for a full refund." 
+ + @pytest.mark.asyncio + async def test_all_triples_in_retrieval_graph(self): + """All emitted triples should be in the urn:graph:retrieval graph.""" + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + for event in events: + for t in event["triples"]: + assert t.g == "urn:graph:retrieval", ( + f"Triple {t.s.iri} {t.p.iri} should be in " + f"urn:graph:retrieval, got {t.g}" + ) diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 597d3366..00d8b72a 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -465,12 +465,15 @@ class TestQuery: return_value=(["entity1", "entity2"], ["concept1"]) ) - query.follow_edges_batch = AsyncMock(return_value={ - ("entity1", "predicate1", "object1"), - ("entity2", "predicate2", "object2") - }) + query.follow_edges_batch = AsyncMock(return_value=( + { + ("entity1", "predicate1", "object1"), + ("entity2", "predicate2", "object2") + }, + {} + )) - subgraph, entities, concepts = await query.get_subgraph("test query") + subgraph, term_map, entities, concepts = await query.get_subgraph("test query") query.get_entities.assert_called_once_with("test query") query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1) @@ -503,7 +506,7 @@ class TestQuery: test_entities = ["entity1", "entity3"] test_concepts = ["concept1"] query.get_subgraph = AsyncMock( - return_value=(test_subgraph, test_entities, test_concepts) + return_value=(test_subgraph, {}, test_entities, test_concepts) ) async def mock_maybe_label(entity): diff --git a/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py b/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py new file mode 100644 
index 00000000..603bd204 --- /dev/null +++ b/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py @@ -0,0 +1,358 @@ +""" +Tests that explain_triples are forwarded correctly through the graph-rag +service and client layers. + +Covers: +- Service: explain messages include triples from the provenance callback +- Client: explain_callback receives explain_triples from the response +- End-to-end: triples survive the full service → client → callback chain +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.schema import ( + GraphRagQuery, GraphRagResponse, + Triple, Term, IRI, LITERAL, +) +from trustgraph.base.graph_rag_client import GraphRagClient + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def make_triple(s_iri, p_iri, o_value, o_type=IRI): + """Create a Triple with IRI subject/predicate and typed object.""" + o = ( + Term(type=IRI, iri=o_value) if o_type == IRI + else Term(type=LITERAL, value=o_value) + ) + return Triple( + s=Term(type=IRI, iri=s_iri), + p=Term(type=IRI, iri=p_iri), + o=o, + ) + + +def sample_focus_triples(): + """Focus-style triples with a quoted triple (edge selection).""" + return [ + make_triple( + "urn:trustgraph:focus:abc", + "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", + "https://trustgraph.ai/ns/Focus", + ), + make_triple( + "urn:trustgraph:focus:abc", + "http://www.w3.org/ns/prov#wasDerivedFrom", + "urn:trustgraph:exploration:abc", + ), + make_triple( + "urn:trustgraph:focus:abc", + "https://trustgraph.ai/ns/selectedEdge", + "urn:trustgraph:edge-sel:abc:0", + ), + ] + + +def sample_question_triples(): + """Question-style triples.""" + return [ + make_triple( + "urn:trustgraph:question:abc", + "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", + "https://trustgraph.ai/ns/GraphRagQuestion", + ), + make_triple( + "urn:trustgraph:question:abc", + 
"https://trustgraph.ai/ns/query", + "What is quantum computing?", + o_type=LITERAL, + ), + ] + + +# --------------------------------------------------------------------------- +# Service-level: explain messages carry triples +# --------------------------------------------------------------------------- + +class TestGraphRagServiceExplainTriples: + """Test that the graph-rag service includes explain_triples in messages.""" + + @patch('trustgraph.retrieval.graph_rag.rag.GraphRag') + @pytest.mark.asyncio + async def test_explain_messages_include_triples(self, mock_graph_rag_class): + """ + When the provenance callback is invoked with triples, the service + should include them in the explain response message. + """ + from trustgraph.retrieval.graph_rag.rag import Processor + + processor = Processor( + taskgroup=MagicMock(), + id="test-processor", + entity_limit=50, + triple_limit=30, + max_subgraph_size=150, + max_path_length=2, + ) + + mock_rag_instance = AsyncMock() + mock_graph_rag_class.return_value = mock_rag_instance + + question_triples = sample_question_triples() + focus_triples = sample_focus_triples() + + async def mock_query(**kwargs): + explain_callback = kwargs.get('explain_callback') + if explain_callback: + await explain_callback( + question_triples, "urn:trustgraph:question:abc" + ) + await explain_callback( + focus_triples, "urn:trustgraph:focus:abc" + ) + return "The answer." 
+ + mock_rag_instance.query.side_effect = mock_query + + msg = MagicMock() + msg.value.return_value = GraphRagQuery( + query="What is quantum computing?", + user="trustgraph", + collection="default", + streaming=False, + ) + msg.properties.return_value = {"id": "test-id"} + + consumer = MagicMock() + flow = MagicMock() + mock_response = AsyncMock() + mock_provenance = AsyncMock() + + def flow_router(name): + if name == "response": + return mock_response + if name == "explainability": + return mock_provenance + return AsyncMock() + + flow.side_effect = flow_router + + await processor.on_request(msg, consumer, flow) + + # Find the explain messages + explain_msgs = [ + call[0][0] + for call in mock_response.send.call_args_list + if call[0][0].message_type == "explain" + ] + + assert len(explain_msgs) == 2 + + # First explain message should carry question triples + assert explain_msgs[0].explain_id == "urn:trustgraph:question:abc" + assert explain_msgs[0].explain_triples == question_triples + + # Second explain message should carry focus triples + assert explain_msgs[1].explain_id == "urn:trustgraph:focus:abc" + assert explain_msgs[1].explain_triples == focus_triples + + +# --------------------------------------------------------------------------- +# Client-level: explain_callback receives triples +# --------------------------------------------------------------------------- + +class TestGraphRagClientExplainForwarding: + """Test that GraphRagClient.rag() forwards explain_triples to callback.""" + + @pytest.mark.asyncio + async def test_explain_callback_receives_triples(self): + """ + The explain_callback should receive (explain_id, explain_graph, + explain_triples) — not just (explain_id, explain_graph). 
+ """ + focus_triples = sample_focus_triples() + + # Simulate the response sequence the client would receive + responses = [ + GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:focus:abc", + explain_graph="urn:graph:retrieval", + explain_triples=focus_triples, + ), + GraphRagResponse( + message_type="chunk", + response="The answer.", + end_of_stream=True, + ), + GraphRagResponse( + message_type="chunk", + response="", + end_of_session=True, + ), + ] + + # Capture what the explain_callback receives + received_calls = [] + + async def explain_callback(explain_id, explain_graph, explain_triples): + received_calls.append({ + "explain_id": explain_id, + "explain_graph": explain_graph, + "explain_triples": explain_triples, + }) + + # Patch self.request to feed responses to the recipient + client = GraphRagClient.__new__(GraphRagClient) + + async def mock_request(req, timeout=600, recipient=None): + for resp in responses: + done = await recipient(resp) + if done: + return resp + + client.request = mock_request + + result = await client.rag( + query="test", + explain_callback=explain_callback, + ) + + assert result == "The answer." + assert len(received_calls) == 1 + assert received_calls[0]["explain_id"] == "urn:trustgraph:focus:abc" + assert received_calls[0]["explain_graph"] == "urn:graph:retrieval" + assert received_calls[0]["explain_triples"] == focus_triples + + @pytest.mark.asyncio + async def test_explain_callback_receives_empty_triples(self): + """ + When an explain event has no triples, the callback should still + receive an empty list (not None or missing). 
+ """ + responses = [ + GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:question:abc", + explain_graph="urn:graph:retrieval", + explain_triples=[], + ), + GraphRagResponse( + message_type="chunk", + response="Answer.", + end_of_stream=True, + end_of_session=True, + ), + ] + + received_calls = [] + + async def explain_callback(explain_id, explain_graph, explain_triples): + received_calls.append(explain_triples) + + client = GraphRagClient.__new__(GraphRagClient) + + async def mock_request(req, timeout=600, recipient=None): + for resp in responses: + done = await recipient(resp) + if done: + return resp + + client.request = mock_request + + await client.rag(query="test", explain_callback=explain_callback) + + assert len(received_calls) == 1 + assert received_calls[0] == [] + + @pytest.mark.asyncio + async def test_multiple_explain_events_all_forward_triples(self): + """ + Each explain event in a session should forward its own triples. + """ + q_triples = sample_question_triples() + f_triples = sample_focus_triples() + + responses = [ + GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:question:abc", + explain_graph="urn:graph:retrieval", + explain_triples=q_triples, + ), + GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:focus:abc", + explain_graph="urn:graph:retrieval", + explain_triples=f_triples, + ), + GraphRagResponse( + message_type="chunk", + response="Answer.", + end_of_stream=True, + end_of_session=True, + ), + ] + + received_calls = [] + + async def explain_callback(explain_id, explain_graph, explain_triples): + received_calls.append({ + "explain_id": explain_id, + "explain_triples": explain_triples, + }) + + client = GraphRagClient.__new__(GraphRagClient) + + async def mock_request(req, timeout=600, recipient=None): + for resp in responses: + done = await recipient(resp) + if done: + return resp + + client.request = mock_request + + await client.rag(query="test", 
explain_callback=explain_callback) + + assert len(received_calls) == 2 + assert received_calls[0]["explain_id"] == "urn:trustgraph:question:abc" + assert received_calls[0]["explain_triples"] == q_triples + assert received_calls[1]["explain_id"] == "urn:trustgraph:focus:abc" + assert received_calls[1]["explain_triples"] == f_triples + + @pytest.mark.asyncio + async def test_no_explain_callback_does_not_error(self): + """ + When no explain_callback is provided, explain events should be + silently skipped without errors. + """ + responses = [ + GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:question:abc", + explain_graph="urn:graph:retrieval", + explain_triples=sample_question_triples(), + ), + GraphRagResponse( + message_type="chunk", + response="Answer.", + end_of_stream=True, + end_of_session=True, + ), + ] + + client = GraphRagClient.__new__(GraphRagClient) + + async def mock_request(req, timeout=600, recipient=None): + for resp in responses: + done = await recipient(resp) + if done: + return resp + + client.request = mock_request + + result = await client.rag(query="test") + assert result == "Answer." diff --git a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py new file mode 100644 index 00000000..36536f7d --- /dev/null +++ b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py @@ -0,0 +1,482 @@ +""" +Integration test: run a full GraphRag.query() with mocked subsidiary clients +and verify the explain_callback receives the complete provenance chain +in the correct order with correct structure. + +This tests the real query() method end-to-end, not just the triple builders. 
+""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock +from dataclasses import dataclass + +from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id +from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, + TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, + TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE, + TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT, + TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + return [ + t for t in triples + if t.p.iri == predicate + and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in triples + ) + + +def derived_from(triples, subject): + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return t.o.iri if t else None + + +@dataclass +class EmbeddingMatch: + """Mimics the result from graph_embeddings_client.query().""" + entity: Term + + +# --------------------------------------------------------------------------- +# Mock setup +# --------------------------------------------------------------------------- + +# A tiny knowledge graph: 2 entities, 3 edges +ENTITY_A = "http://example.com/QuantumComputing" +ENTITY_B = "http://example.com/Physics" +EDGE_1 = (ENTITY_A, "http://schema.org/relatedTo", ENTITY_B) +EDGE_2 = (ENTITY_A, "http://schema.org/name", "Quantum Computing") +EDGE_3 = (ENTITY_B, "http://schema.org/name", "Physics") + + +def 
make_schema_triple(s, p, o): + """Create a SchemaTriple from string values.""" + return SchemaTriple( + s=Term(type=IRI, iri=s), + p=Term(type=IRI, iri=p), + o=Term(type=IRI, iri=o) if o.startswith("http") else Term(type=LITERAL, value=o), + ) + + +def build_mock_clients(): + """ + Build mock clients that simulate a small knowledge graph query. + + Client call sequence during query(): + 1. prompt_client.prompt("extract-concepts", ...) -> concepts + 2. embeddings_client.embed(concepts) -> vectors + 3. graph_embeddings_client.query(vector, ...) -> entity matches + 4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch) + 5. triples_client.query(s, LABEL, ...) -> labels (maybe_label) + 6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges + 7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning + 8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns []) + 9. prompt_client.prompt("kg-synthesis", ...) -> final answer + """ + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + graph_embeddings_client = AsyncMock() + triples_client = AsyncMock() + + # 1. Concept extraction + prompt_responses = {} + prompt_responses["extract-concepts"] = "quantum computing\nphysics" + + # 2. Embedding vectors (simple fake vectors) + embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + + # 3. Entity lookup - return our two entities + graph_embeddings_client.query.return_value = [ + EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_A)), + EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)), + ] + + # 4. Triple queries (follow_edges_batch) - return our edges + kg_triples = [ + make_schema_triple(*EDGE_1), + make_schema_triple(*EDGE_2), + make_schema_triple(*EDGE_3), + ] + triples_client.query_stream.return_value = kg_triples + + # 5. 
Label resolution - return entity as its own label (simplify) + async def mock_label_query(s=None, p=None, o=None, limit=1, + user=None, collection=None, g=None): + return [] # No labels found, will fall back to URI + triples_client.query.side_effect = mock_label_query + + # 6+7. Edge scoring and reasoning: dynamically score/reason about + # whatever edges the query method sends us, since edge IDs are computed + # from str(Term) representations which include the full dataclass repr. + synthesis_answer = "Quantum computing applies physics principles to computation." + + async def mock_prompt(template_id, variables=None, **kwargs): + if template_id == "extract-concepts": + return prompt_responses["extract-concepts"] + elif template_id == "kg-edge-scoring": + # Score all edges highly, using the IDs that GraphRag computed + edges = variables.get("knowledge", []) + return [ + {"id": e["id"], "score": 10 - i} + for i, e in enumerate(edges) + ] + elif template_id == "kg-edge-reasoning": + # Provide reasoning for each edge + edges = variables.get("knowledge", []) + return [ + {"id": e["id"], "reasoning": f"Relevant edge {i}"} + for i, e in enumerate(edges) + ] + elif template_id == "kg-synthesis": + return synthesis_answer + return "" + + prompt_client.prompt.side_effect = mock_prompt + + return prompt_client, embeddings_client, graph_embeddings_client, triples_client + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestGraphRagQueryProvenance: + """ + Run a real GraphRag.query() and verify the provenance chain emitted + via explain_callback. 
+ """ + + @pytest.mark.asyncio + async def test_explain_callback_receives_five_events(self): + """query() should emit exactly 5 explain events.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, # skip semantic pre-filter for simplicity + ) + + assert len(events) == 5, ( + f"Expected 5 explain events (question, grounding, exploration, " + f"focus, synthesis), got {len(events)}" + ) + + @pytest.mark.asyncio + async def test_events_have_correct_types_in_order(self): + """ + Events should arrive as: + question, grounding, exploration, focus, synthesis. + """ + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + expected_types = [ + TG_GRAPH_RAG_QUESTION, + TG_GROUNDING, + TG_EXPLORATION, + TG_FOCUS, + TG_SYNTHESIS, + ] + + for i, expected_type in enumerate(expected_types): + uri = events[i]["explain_id"] + triples = events[i]["triples"] + assert has_type(triples, uri, expected_type), ( + f"Event {i} (uri={uri}) should have type {expected_type}" + ) + + @pytest.mark.asyncio + async def test_derivation_chain_links_correctly(self): + """ + Each event's URI should link to the previous via wasDerivedFrom: + grounding → question → (none) + exploration → grounding + focus → exploration + synthesis → focus + """ + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum 
computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + uris = [e["explain_id"] for e in events] + all_triples = [] + for e in events: + all_triples.extend(e["triples"]) + + # question has no parent + assert derived_from(all_triples, uris[0]) is None + + # grounding → question + assert derived_from(all_triples, uris[1]) == uris[0] + + # exploration → grounding + assert derived_from(all_triples, uris[2]) == uris[1] + + # focus → exploration + assert derived_from(all_triples, uris[3]) == uris[2] + + # synthesis → focus + assert derived_from(all_triples, uris[4]) == uris[3] + + @pytest.mark.asyncio + async def test_question_event_carries_query_text(self): + """The question event should contain the original query string.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + q_uri = events[0]["explain_id"] + q_triples = events[0]["triples"] + t = find_triple(q_triples, TG_QUERY, q_uri) + assert t is not None + assert t.o.value == "What is quantum computing?" 
+ + @pytest.mark.asyncio + async def test_grounding_carries_concepts(self): + """The grounding event should list extracted concepts.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + gnd_uri = events[1]["explain_id"] + gnd_triples = events[1]["triples"] + concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri) + concept_values = {t.o.value for t in concepts} + assert "quantum computing" in concept_values + assert "physics" in concept_values + + @pytest.mark.asyncio + async def test_exploration_has_edge_count(self): + """The exploration event should report how many edges were found.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + exp_uri = events[2]["explain_id"] + exp_triples = events[2]["triples"] + t = find_triple(exp_triples, TG_EDGE_COUNT, exp_uri) + assert t is not None + # Should be non-zero (we provided 3 edges, label edges filtered) + assert int(t.o.value) > 0 + + @pytest.mark.asyncio + async def test_focus_has_selected_edges_with_reasoning(self): + """ + The focus event should carry selected edges as quoted triples + with reasoning text. 
+ """ + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + foc_uri = events[3]["explain_id"] + foc_triples = events[3]["triples"] + + # Should have selected edges + selected = find_triples(foc_triples, TG_SELECTED_EDGE, foc_uri) + assert len(selected) > 0, "Focus should have at least one selected edge" + + # Each edge selection should have a quoted triple + edge_t = find_triples(foc_triples, TG_EDGE) + assert len(edge_t) > 0, "Focus should have tg:edge with quoted triples" + for t in edge_t: + assert t.o.triple is not None, "tg:edge object must be a quoted triple" + + # Should have reasoning + reasoning = find_triples(foc_triples, TG_REASONING) + assert len(reasoning) > 0, "Focus should have reasoning for selected edges" + reasoning_texts = {t.o.value for t in reasoning} + assert any(r for r in reasoning_texts), "Reasoning should not be empty" + + @pytest.mark.asyncio + async def test_synthesis_is_answer_type(self): + """The synthesis event should have tg:Answer type.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + syn_uri = events[4]["explain_id"] + syn_triples = events[4]["triples"] + assert has_type(syn_triples, syn_uri, TG_SYNTHESIS) + assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE) + + @pytest.mark.asyncio + async def test_query_returns_answer_text(self): + """query() should still return the synthesised answer.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def 
explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + result = await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + assert result == "Quantum computing applies physics principles to computation." + + @pytest.mark.asyncio + async def test_parent_uri_links_question_to_parent(self): + """When parent_uri is provided, question should derive from it.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + parent = "urn:trustgraph:agent:iteration:xyz" + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + parent_uri=parent, + ) + + q_uri = events[0]["explain_id"] + q_triples = events[0]["triples"] + assert derived_from(q_triples, q_uri) == parent + + @pytest.mark.asyncio + async def test_no_explain_callback_still_works(self): + """query() without explain_callback should return answer normally.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + result = await rag.query( + query="What is quantum computing?", + edge_score_limit=0, + ) + + assert result == "Quantum computing applies physics principles to computation." 
+ + @pytest.mark.asyncio + async def test_all_triples_in_retrieval_graph(self): + """All emitted triples should be in the urn:graph:retrieval graph.""" + clients = build_mock_clients() + rag = GraphRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is quantum computing?", + explain_callback=explain_callback, + edge_score_limit=0, + ) + + for event in events: + for t in event["triples"]: + assert t.g == "urn:graph:retrieval", ( + f"Triple {t.s.iri} {t.p.iri} should be in " + f"urn:graph:retrieval, got {t.g}" + ) diff --git a/trustgraph-base/trustgraph/base/graph_rag_client.py b/trustgraph-base/trustgraph/base/graph_rag_client.py index 32007943..9db23293 100644 --- a/trustgraph-base/trustgraph/base/graph_rag_client.py +++ b/trustgraph-base/trustgraph/base/graph_rag_client.py @@ -15,7 +15,7 @@ class GraphRagClient(RequestResponse): user: User identifier collection: Collection identifier chunk_callback: Optional async callback(text, end_of_stream) for text chunks - explain_callback: Optional async callback(explain_id, explain_graph) for explain notifications + explain_callback: Optional async callback(explain_id, explain_graph, explain_triples) for explain notifications timeout: Request timeout in seconds Returns: @@ -30,7 +30,7 @@ class GraphRagClient(RequestResponse): # Handle explain notifications if resp.message_type == 'explain': if explain_callback and resp.explain_id: - await explain_callback(resp.explain_id, resp.explain_graph) + await explain_callback(resp.explain_id, resp.explain_graph, resp.explain_triples) return False # Continue receiving # Handle text chunks diff --git a/trustgraph-base/trustgraph/clients/document_rag_client.py b/trustgraph-base/trustgraph/clients/document_rag_client.py index 057376fb..365ea09d 100644 --- a/trustgraph-base/trustgraph/clients/document_rag_client.py +++ 
b/trustgraph-base/trustgraph/clients/document_rag_client.py @@ -43,7 +43,7 @@ class DocumentRagClient(BaseClient): user: User identifier collection: Collection identifier chunk_callback: Optional callback(text, end_of_stream) for text chunks - explain_callback: Optional callback(explain_id, explain_graph) for explain notifications + explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications timeout: Request timeout in seconds Returns: @@ -55,7 +55,7 @@ class DocumentRagClient(BaseClient): # Handle explain notifications (response is None/empty, explain_id present) if x.explain_id and not x.response: if explain_callback: - explain_callback(x.explain_id, x.explain_graph) + explain_callback(x.explain_id, x.explain_graph, x.explain_triples) return False # Continue receiving # Handle text chunks diff --git a/trustgraph-base/trustgraph/clients/graph_rag_client.py b/trustgraph-base/trustgraph/clients/graph_rag_client.py index 17d7b0f0..0d33bf91 100644 --- a/trustgraph-base/trustgraph/clients/graph_rag_client.py +++ b/trustgraph-base/trustgraph/clients/graph_rag_client.py @@ -47,7 +47,7 @@ class GraphRagClient(BaseClient): user: User identifier collection: Collection identifier chunk_callback: Optional callback(text, end_of_stream) for text chunks - explain_callback: Optional callback(explain_id, explain_graph) for explain notifications + explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications timeout: Request timeout in seconds Returns: @@ -59,7 +59,7 @@ class GraphRagClient(BaseClient): # Handle explain notifications if x.message_type == 'explain': if explain_callback and x.explain_id: - explain_callback(x.explain_id, x.explain_graph) + explain_callback(x.explain_id, x.explain_graph, x.explain_triples) return False # Continue receiving # Handle text chunks diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index 
f2e85eff..920a3482 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -465,11 +465,18 @@ def exploration_triples( return triples -def _quoted_triple(s: str, p: str, o: str) -> Term: - """Create a quoted triple term (RDF-star) from string values.""" +def _quoted_triple(s, p, o) -> Term: + """Create a quoted triple term (RDF-star). + + Accepts either Term objects (preserving original types) or plain + strings (treated as IRIs for backward compatibility). + """ + s_term = s if isinstance(s, Term) else _iri(s) + p_term = p if isinstance(p, Term) else _iri(p) + o_term = o if isinstance(o, Term) else _iri(o) return Term( type=TRIPLE, - triple=Triple(s=_iri(s), p=_iri(p), o=_iri(o)) + triple=Triple(s=s_term, p=p_term, o=o_term) ) diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 041558ec..6fd96ade 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -39,13 +39,14 @@ class KnowledgeQueryImpl: if respond: from ... schema import AgentResponse - async def explain_callback(explain_id, explain_graph): + async def explain_callback(explain_id, explain_graph, explain_triples=None): self.context.last_sub_explain_uri = explain_id await respond(AgentResponse( chunk_type="explain", content="", explain_id=explain_id, explain_graph=explain_graph, + explain_triples=explain_triples or [], )) if current_uri: diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 704613c6..5cf7b991 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -10,6 +10,7 @@ from collections import OrderedDict from datetime import datetime from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE +from ... 
knowledge import Uri, Literal # Provenance imports from trustgraph.provenance import ( @@ -46,6 +47,26 @@ def term_to_string(term): return term.iri or term.value or str(term) +def to_term(val): + """Convert a Uri, Literal, or string to a schema Term. + + The triples client returns Uri/Literal (str subclasses) rather than + Term objects. This converts them back so provenance quoted triples + preserve the correct type. + """ + if isinstance(val, Term): + return val + if isinstance(val, Uri): + return Term(type=IRI, iri=str(val)) + if isinstance(val, Literal): + return Term(type=LITERAL, value=str(val)) + # Fallback: treat as IRI if it looks like one, otherwise literal + s = str(val) + if s.startswith(("http://", "https://", "urn:")): + return Term(type=IRI, iri=s) + return Term(type=LITERAL, value=s) + + def edge_id(s, p, o): """Generate an 8-character hash ID for an edge (s, p, o).""" edge_str = f"{s}|{p}|{o}" @@ -258,10 +279,18 @@ class Query: return all_triples async def follow_edges_batch(self, entities, max_depth): - """Optimized iterative graph traversal with batching""" + """Optimized iterative graph traversal with batching. + + Returns: + tuple: (subgraph, term_map) where subgraph is a set of + (str, str, str) tuples and term_map maps each string tuple + to its original (Term, Term, Term) for type-preserving + provenance. 
+ """ visited = set() current_level = set(entities) subgraph = set() + term_map = {} # (str, str, str) -> (Term, Term, Term) for depth in range(max_depth): if not current_level or len(subgraph) >= self.max_subgraph_size: @@ -282,6 +311,7 @@ class Query: for triple in triples: triple_tuple = (str(triple.s), str(triple.p), str(triple.o)) subgraph.add(triple_tuple) + term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o)) # Collect entities for next level (only from s and o positions) if depth < max_depth - 1: # Don't collect for final depth @@ -293,13 +323,13 @@ class Query: # Stop if subgraph size limit reached if len(subgraph) >= self.max_subgraph_size: - return subgraph + return subgraph, term_map # Update for next iteration visited.update(current_level) current_level = next_level - return subgraph + return subgraph, term_map async def follow_edges(self, ent, subgraph, path_length): """Legacy method - replaced by follow_edges_batch""" @@ -311,7 +341,7 @@ class Query: return # For backward compatibility, convert to new approach - batch_result = await self.follow_edges_batch([ent], path_length) + batch_result, _ = await self.follow_edges_batch([ent], path_length) subgraph.update(batch_result) async def get_subgraph(self, query): @@ -319,9 +349,10 @@ class Query: Get subgraph by extracting concepts, finding entities, and traversing. Returns: - tuple: (subgraph, entities, concepts) where subgraph is a list of - (s, p, o) tuples, entities is the seed entity list, and concepts - is the extracted concept list. + tuple: (subgraph, term_map, entities, concepts) where subgraph is + a list of (s, p, o) string tuples, term_map maps each string + tuple to its original (Term, Term, Term), entities is the seed + entity list, and concepts is the extracted concept list. 
""" entities, concepts = await self.get_entities(query) @@ -330,9 +361,9 @@ class Query: logger.debug("Getting subgraph...") # Use optimized batch traversal instead of sequential processing - subgraph = await self.follow_edges_batch(entities, self.max_path_length) + subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length) - return list(subgraph), entities, concepts + return list(subgraph), term_map, entities, concepts async def resolve_labels_batch(self, entities): """Resolve labels for multiple entities in parallel""" @@ -353,7 +384,7 @@ class Query: - entities: list of seed entity URI strings - concepts: list of concept strings extracted from query """ - subgraph, entities, concepts = await self.get_subgraph(query) + subgraph, term_map, entities, concepts = await self.get_subgraph(query) # Filter out label triples filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL] @@ -377,7 +408,7 @@ class Query: # Apply labels to subgraph and build URI mapping labeled_edges = [] - uri_map = {} # Maps edge_id of labeled edge -> original URI triple + uri_map = {} # Maps edge_id of labeled edge -> original Term triple for s, p, o in filtered_subgraph: labeled_triple = ( @@ -387,9 +418,9 @@ class Query: ) labeled_edges.append(labeled_triple) - # Map from labeled edge ID to original URIs + # Map from labeled edge ID to original Terms (preserving types) labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2]) - uri_map[labeled_eid] = (s, p, o) + uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o)) labeled_edges = labeled_edges[0:self.max_subgraph_size] @@ -419,12 +450,14 @@ class Query: # Step 1: Find subgraphs containing these edges via tg:contains subgraph_tasks = [] for s, p, o in edge_uris: + # s, p, o may be Term objects (preserving types) or strings + s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s) + p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p) + o_term = o if isinstance(o, Term) 
else Term(type=IRI, iri=o) quoted = Term( type=TRIPLE, triple=SchemaTriple( - s=Term(type=IRI, iri=s), - p=Term(type=IRI, iri=p), - o=Term(type=IRI, iri=o), + s=s_term, p=p_term, o=o_term, ) ) subgraph_tasks.append(