mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Compare commits
3 commits
e899370d98
...
aff96e57cb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aff96e57cb | ||
|
|
e81418c58f | ||
|
|
4b5bfacab1 |
15 changed files with 2840 additions and 32 deletions
29
dev-tools/explainable-ai/README.md
Normal file
29
dev-tools/explainable-ai/README.md
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Explainable AI Demo
|
||||
|
||||
Demonstrates the TrustGraph streaming agent API with inline explainability
|
||||
events. Sends an agent query, receives streaming thinking/observation/answer
|
||||
chunks alongside RDF provenance events, then resolves the full provenance
|
||||
chain from answer back to source documents.
|
||||
|
||||
## What it shows
|
||||
|
||||
- Streaming agent responses (thinking, observation, answer)
|
||||
- Inline explainability events with RDF triples (W3C PROV + TrustGraph namespace)
|
||||
- Label resolution for entity and predicate URIs
|
||||
- Provenance chain traversal: subgraph → chunk → page → document
|
||||
- Source text retrieval from the librarian using chunk IDs
|
||||
|
||||
## Prerequisites
|
||||
|
||||
A running TrustGraph instance with at least one loaded document and a
|
||||
running flow. The default configuration connects to `ws://localhost:8088`.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
npm install
|
||||
node index.js
|
||||
```
|
||||
|
||||
Edit the `QUESTION` and `SOCKET_URL` constants at the top of `index.js`
|
||||
to change the query or target instance.
|
||||
552
dev-tools/explainable-ai/index.js
Normal file
552
dev-tools/explainable-ai/index.js
Normal file
|
|
@ -0,0 +1,552 @@
|
|||
|
||||
// ============================================================================
|
||||
// TrustGraph Explainability API Demo
|
||||
// ============================================================================
|
||||
//
|
||||
// This example demonstrates how to use the TrustGraph streaming agent API
|
||||
// with explainability events. It shows how to:
|
||||
//
|
||||
// 1. Send an agent query and receive streaming thinking/observation/answer
|
||||
// 2. Receive and parse explainability events as they arrive
|
||||
// 3. Resolve the provenance chain for knowledge graph edges:
|
||||
// subgraph -> chunk -> page -> document
|
||||
// 4. Fetch source text from the librarian using chunk IDs
|
||||
//
|
||||
// Explainability events use RDF triples (W3C PROV ontology + TrustGraph
|
||||
// namespace) to describe the retrieval pipeline. The key event types are:
|
||||
//
|
||||
// - AgentQuestion: The initial user query
|
||||
// - Analysis/ToolUse: Agent deciding which tool to invoke
|
||||
// - GraphRagQuestion: A sub-query sent to the Graph RAG pipeline
|
||||
// - Grounding: Concepts extracted from the query for graph traversal
|
||||
// - Exploration: Entities discovered during knowledge graph traversal
|
||||
// - Focus: The selected knowledge graph edges (triples) used for context
|
||||
// - Synthesis: The RAG answer synthesised from retrieved context
|
||||
// - Observation: The tool result returned to the agent
|
||||
// - Conclusion/Answer: The agent's final answer
|
||||
//
|
||||
// Each event carries RDF triples that link back through the provenance chain,
|
||||
// allowing full traceability from answer back to source documents.
|
||||
// ============================================================================
|
||||
|
||||
import { createTrustGraphSocket } from '@trustgraph/client';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Identity under which all requests are made.
const USER = "trustgraph";

// Simple question
const QUESTION = "Tell me about the author of the document";

// Likely to trigger the deep research plan-and-execute pattern
//const QUESTION = "Do deep research and explain the risks posed by globalisation in the modern world";

// WebSocket endpoint of the TrustGraph API gateway.
const SOCKET_URL = "ws://localhost:8088/api/v1/socket";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RDF predicates and TrustGraph namespace constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Core RDF / RDFS / PROV predicates used when interpreting explain triples.
const RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type";
const RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label";
const PROV_DERIVED = "http://www.w3.org/ns/prov#wasDerivedFrom";

// TrustGraph namespace terms. Capitalised terms are event types
// (used as rdf:type objects); lower-case terms are predicates.
const TG_GROUNDING = "https://trustgraph.ai/ns/Grounding";
const TG_CONCEPT = "https://trustgraph.ai/ns/concept";
const TG_EXPLORATION = "https://trustgraph.ai/ns/Exploration";
const TG_ENTITY = "https://trustgraph.ai/ns/entity";
const TG_FOCUS = "https://trustgraph.ai/ns/Focus";
const TG_EDGE = "https://trustgraph.ai/ns/edge";
const TG_CONTAINS = "https://trustgraph.ai/ns/contains";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Utility: check whether a set of triples assigns a given RDF type to an ID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// True when `triples` contains an rdf:type assertion giving `id` the type `type`.
const isType = (triples, id, type) => {
    const assertsType = (t) =>
        t.s.i === id && t.p.i === RDF_TYPE && t.o.i === type;
    return triples.some(assertsType);
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Utility: word-wrap text for display
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Collapse whitespace in `text` and word-wrap it to `width` columns,
// prefixing every output line with `indent`. At most `maxLines` lines are
// produced; when the text is truncated, " ..." is appended to the last line.
//
// Fix: previously a spurious " ..." was appended whenever the text fit in
// exactly `maxLines` lines, because `remaining` was never cleared after the
// final piece was pushed.
const wrapText = (text, width, indent, maxLines) => {
    const clean = text.replace(/\s+/g, " ").trim();
    const lines = [];
    let remaining = clean;
    while (remaining.length > 0 && lines.length < maxLines) {
        if (remaining.length <= width) {
            lines.push(remaining);
            remaining = "";   // fully consumed — no truncation marker needed
            break;
        }
        // Break at the last space within `width`; hard-break mid-word if none
        let breakAt = remaining.lastIndexOf(" ", width);
        if (breakAt <= 0) breakAt = width;
        lines.push(remaining.substring(0, breakAt));
        remaining = remaining.substring(breakAt).trimStart();
    }
    // Leftover text means we hit the line cap: mark the truncation
    if (remaining.length > 0 && lines.length >= maxLines)
        lines[lines.length - 1] += " ...";
    return lines.map(l => indent + l).join("\n");
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Connect to TrustGraph
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
console.log("=".repeat(80));
console.log("TrustGraph Explainability API Demo");
console.log("=".repeat(80));
console.log(`Connecting to: ${SOCKET_URL}`);
console.log(`Question: ${QUESTION}`);
console.log("=".repeat(80));

// Open the WebSocket connection. The second argument (auth token) is not
// needed here. NOTE(review): the socket is used immediately after
// construction — presumably the client queues messages until the
// connection opens; confirm against the @trustgraph/client docs.
const client = createTrustGraphSocket(USER, undefined, SOCKET_URL);

console.log("Connected, sending query...\n");

// Get a flow handle. Flows provide access to AI operations (agent, RAG,
// text completion, etc.) as well as knowledge graph queries.
const flow = client.flow("default");

// Get a librarian handle for fetching source document text.
const librarian = client.librarian();
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Inline explain event printing
|
||||
// ---------------------------------------------------------------------------
|
||||
// Explain events arrive during streaming alongside thinking/observation/
|
||||
// answer chunks. We print a summary immediately and store them for
|
||||
// post-processing (label resolution and provenance lookups require async
|
||||
// queries that can't run inside the synchronous callback).
|
||||
|
||||
// Explain events collected during streaming, consumed by post-processing.
const explainEvents = [];

// Print a one-line summary of an explain event as it streams in.
// Grounding and Exploration events get extra inline detail; every event
// at least shows its short type name(s).
const printExplainInline = (explainEvent) => {
    const { explainId, explainTriples } = explainEvent;
    if (!explainTriples) return;

    // All triples whose subject is this event and whose predicate matches
    const about = (pred) =>
        explainTriples.filter(t => t.s.i === explainId && t.p.i === pred);

    // rdf:type triples identify which pipeline step this event represents;
    // show short names (e.g. "Grounding" instead of the full URI)
    const shortName = (uri) => uri.split("/").pop().split("#").pop();
    const shortTypes = about(RDF_TYPE).map(t => shortName(t.o.i)).join(", ");
    console.log(`  [explain] ${shortTypes}`);

    // Grounding events carry the seed concepts extracted from the query,
    // used to begin knowledge graph traversal.
    if (isType(explainTriples, explainId, TG_GROUNDING)) {
        const concepts = about(TG_CONCEPT).map(t => t.o.v);
        console.log(`    Grounding concepts: ${concepts.join(", ")}`);
    }

    // Exploration events list entities found during traversal; only the
    // count is shown here — labelled names are printed after resolution.
    if (isType(explainTriples, explainId, TG_EXPLORATION)) {
        const count = about(TG_ENTITY).length;
        console.log(`    Entities: ${count} found (see below)`);
    }
};
|
||||
|
||||
// Streaming callback: print a summary immediately, then retain the event
// for post-processing (label/provenance resolution) after the agent call
// completes.
const collectExplain = (explainEvent) => {
    printExplainInline(explainEvent);
    explainEvents.push(explainEvent);
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Label resolution
|
||||
// ---------------------------------------------------------------------------
|
||||
// Many explain triples reference entities and predicates by URI. We query
|
||||
// the knowledge graph for rdfs:label to get human-readable names.
|
||||
|
||||
// Resolve rdfs:label values for a batch of URIs in parallel.
// Returns a Map of uri -> label; URIs with no label (or whose lookup
// failed) are simply absent, so callers fall back to the raw URI.
const resolveLabels = async (uris) => {
    const resolved = new Map();

    const lookup = async (uri) => {
        try {
            const hits = await flow.triplesQuery(
                { t: "i", i: uri },
                { t: "i", i: RDFS_LABEL },
            );
            if (hits.length > 0) resolved.set(uri, hits[0].o.v);
        } catch (err) {
            // Best-effort: no label found, fall back to URI
        }
    };

    await Promise.all(uris.map(lookup));
    return resolved;
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Provenance resolution for knowledge graph edges
|
||||
// ---------------------------------------------------------------------------
|
||||
// Focus events contain the knowledge graph triples (edges) that were selected
|
||||
// as context for the RAG answer. Each edge can be traced back through the
|
||||
// provenance chain to the original source document:
|
||||
//
|
||||
// subgraph --contains--> <<edge triple>> (RDF-star triple term)
|
||||
// subgraph --wasDerivedFrom--> chunk (text chunk)
|
||||
// chunk --wasDerivedFrom--> page (document page)
|
||||
// page --wasDerivedFrom--> document (original document)
|
||||
//
|
||||
// The chunk URI also serves as the content ID in the librarian, so it can
|
||||
// be used to fetch the actual source text.
|
||||
|
||||
// Trace each knowledge-graph edge (an RDF-star triple term) back through
// the provenance chain: subgraph -> chunk -> page -> document.
// Returns a Map keyed by JSON.stringify(edge triple); each value holds as
// much of { subgraph, chunk, page, document } as could be resolved.
const resolveEdgeSources = async (edgeTriples) => {
    const iri = (uri) => ({ t: "i", i: uri });
    const sources = new Map();

    await Promise.all(edgeTriples.map(async (tr) => {
        // Edges are keyed by their serialised form so consumers can rebuild
        // the same key from the same triple object.
        const key = JSON.stringify(tr);
        try {
            // Step 1: Find the subgraph that contains this edge triple.
            // The query uses an RDF-star triple term as the object: the
            // knowledge graph stores subgraph -> contains -> <<s, p, o>>.
            const subgraphResults = await flow.triplesQuery(
                undefined,
                iri(TG_CONTAINS),
                { t: "t", tr },
            );
            if (subgraphResults.length === 0) {
                // Only literal/IRI objects are worth reporting as misses
                if (tr.o.t === "l" || tr.o.t === "i") {
                    console.log(`    No source match for triple:`);
                    console.log(`      s: ${tr.s.i}`);
                    console.log(`      p: ${tr.p.i}`);
                    console.log(`      o: ${JSON.stringify(tr.o)}`);
                }
                return;
            }
            const subgraph = subgraphResults[0].s.i;

            // Step 2: Walk wasDerivedFrom chain: subgraph -> chunk
            const chunkResults = await flow.triplesQuery(
                iri(subgraph), iri(PROV_DERIVED),
            );
            if (chunkResults.length === 0) {
                // Partial chain: record what we have and stop
                sources.set(key, { subgraph });
                return;
            }
            const chunk = chunkResults[0].o.i;

            // Step 3: chunk -> page
            const pageResults = await flow.triplesQuery(
                iri(chunk), iri(PROV_DERIVED),
            );
            if (pageResults.length === 0) {
                sources.set(key, { subgraph, chunk });
                return;
            }
            const page = pageResults[0].o.i;

            // Step 4: page -> document (may legitimately be absent)
            const docResults = await flow.triplesQuery(
                iri(page), iri(PROV_DERIVED),
            );
            const document = docResults.length > 0 ? docResults[0].o.i : undefined;

            sources.set(key, { subgraph, chunk, page, document });
        } catch (e) {
            // Best-effort: a failed query leaves this edge without
            // source attribution
        }
    }));

    return sources;
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Collect URIs that need label resolution
|
||||
// ---------------------------------------------------------------------------
|
||||
// Scans explain events for entity URIs (from Exploration events) and edge
|
||||
// term URIs (from Focus events) so we can batch-resolve their labels.
|
||||
|
||||
// Gather every URI that may benefit from label resolution: entity URIs from
// Exploration events plus the subject/predicate/object IRIs of each Focus
// edge triple. Returns a Set of URIs.
const collectUris = (events) => {
    const found = new Set();

    // Add any IRI terms of an RDF-star edge triple
    const addEdgeTerms = (tr) => {
        for (const term of [tr.s, tr.p, tr.o]) {
            if (term.t === "i") found.add(term.i);
        }
    };

    for (const { explainId, explainTriples } of events) {
        if (!explainTriples) continue;

        // Entity URIs from exploration
        if (isType(explainTriples, explainId, TG_EXPLORATION)) {
            explainTriples
                .filter(t => t.s.i === explainId && t.p.i === TG_ENTITY)
                .forEach(t => found.add(t.o.i));
        }

        // Subject, predicate, and object URIs from focus edge triples
        if (isType(explainTriples, explainId, TG_FOCUS)) {
            explainTriples
                .filter(t => t.p.i === TG_EDGE && t.o.t === "t")
                .forEach(t => addEdgeTerms(t.o.tr));
        }
    }
    return found;
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Collect edge triples from Focus events
|
||||
// ---------------------------------------------------------------------------
|
||||
// Focus events contain selectedEdge -> edge relationships. Each edge's
|
||||
// object is an RDF-star triple term ({t: "t", tr: {s, p, o}}) representing
|
||||
// the actual knowledge graph triple used as RAG context.
|
||||
|
||||
// Collect the RDF-star edge triples ({s, p, o}) carried by Focus events —
// the knowledge graph triples actually used as RAG context.
const collectEdgeTriples = (events) =>
    events
        .filter(e =>
            e.explainTriples &&
            isType(e.explainTriples, e.explainId, TG_FOCUS))
        .flatMap(e =>
            e.explainTriples
                .filter(t => t.p.i === TG_EDGE && t.o.t === "t")
                .map(t => t.o.tr));
|
||||
|
||||
// ---------------------------------------------------------------------------
// Print knowledge graph edges with provenance
// ---------------------------------------------------------------------------
// Displays each edge triple with resolved labels and its source location
// (chunk -> page -> document). `maxEdges` caps the number of edges printed
// per Focus event; it defaults to 20, matching the previous hard-coded
// limit, so existing callers are unaffected.

const printFocusEdges = (events, labels, edgeSources, maxEdges = 20) => {
    const label = (uri) => labels.get(uri) || uri;

    for (const { explainId, explainTriples } of events) {
        if (!explainTriples) continue;
        if (!isType(explainTriples, explainId, TG_FOCUS)) continue;

        // Render a triple term: labelled URI for IRIs, literal value otherwise
        const termValue = (term) =>
            term.t === "i" ? label(term.i) : (term.v || "?");

        const edges = explainTriples
            .filter(t => t.p.i === TG_EDGE && t.o.t === "t")
            .map(t => t.o.tr);

        const display = edges.slice(0, maxEdges);
        for (const tr of display) {
            console.log(`  ${termValue(tr.s)} -> ${termValue(tr.p)} -> ${termValue(tr.o)}`);
            // Provenance was keyed by the serialised triple in resolveEdgeSources
            const src = edgeSources.get(JSON.stringify(tr));
            if (src) {
                const parts = [];
                if (src.chunk) parts.push(label(src.chunk));
                if (src.page) parts.push(label(src.page));
                if (src.document) parts.push(label(src.document));
                if (parts.length > 0)
                    console.log(`    Source: ${parts.join(" -> ")}`);
            }
        }
        if (edges.length > maxEdges)
            console.log(`  ... and ${edges.length - maxEdges} more`);
    }
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Fetch chunk text from the librarian
|
||||
// ---------------------------------------------------------------------------
|
||||
// The chunk URI (e.g. urn:chunk:UUID) serves as a universal ID that ties
|
||||
// together provenance metadata, embeddings, and the source text content.
|
||||
// The librarian stores the original text keyed by this same URI, so we
|
||||
// can retrieve it with streamDocument(chunkUri).
|
||||
|
||||
// Fetch the stored text for a chunk URI from the librarian.
// Adapts the callback-based streamDocument API to a Promise: content
// chunks are concatenated until the stream signals completion. The
// resolved value is the raw streamed content (base64-encoded — decoded
// by the caller).
const fetchChunkText = (chunkUri) => {
    return new Promise((resolve, reject) => {
        let text = "";
        librarian.streamDocument(
            chunkUri,
            (content, chunkIndex, totalChunks, complete) => {
                text += content;
                if (complete) resolve(text);
            },
            (error) => reject(error),
        );
    });
};
|
||||
|
||||
// ===========================================================================
|
||||
// Send the agent query
|
||||
// ===========================================================================
|
||||
// The agent callback receives four types of streaming content:
|
||||
// - think: the agent's reasoning (chain-of-thought)
|
||||
// - observe: tool results returned to the agent
|
||||
// - answer: the final answer being generated
|
||||
// - error: any errors during processing
|
||||
//
|
||||
// The onExplain callback fires for each explainability event, delivering
|
||||
// RDF triples that describe what happened at each pipeline stage.
|
||||
|
||||
// Accumulators for streamed think/observe/answer chunks. Each is flushed
// (printed and reset) when its stream signals completion.
let thought = "";
let obs = "";
let ans = "";

await flow.agent(

    QUESTION,

    // Think callback: agent reasoning / chain-of-thought
    (chunk, complete, messageId, metadata) => {
        thought += chunk;
        if (complete) {
            console.log("\nThinking:", thought, "\n");
            thought = "";
        }
    },

    // Observe callback: tool results returned to the agent
    (chunk, complete, messageId, metadata) => {
        obs += chunk;
        if (complete) {
            console.log("\nObservation:", obs, "\n");
            obs = "";
        }
    },

    // Answer callback: the agent's final response
    (chunk, complete, messageId, metadata) => {
        ans += chunk;
        if (complete) {
            console.log("\nAnswer:", ans, "\n");
            ans = "";
        }
    },

    // Error callback: dumped as structured JSON for easier debugging
    (error) => {
        console.log(JSON.stringify({ type: "error", error }, null, 2));
    },

    // Explain callback: explainability events with RDF triples; each is
    // printed inline and retained for post-processing below
    (explainEvent) => {
        collectExplain(explainEvent);
    }

);
|
||||
|
||||
// ===========================================================================
|
||||
// Post-processing: resolve labels, provenance, and source text
|
||||
// ===========================================================================
|
||||
// After the agent query completes, we have all the explain events. Now we
|
||||
// can make async queries to:
|
||||
// 1. Trace each edge back to its source document (provenance chain)
|
||||
// 2. Resolve URIs to human-readable labels
|
||||
// 3. Fetch the original text for each source chunk
|
||||
|
||||
console.log("Resolving provenance...\n");

// Resolve the provenance chain for each knowledge graph edge
const edgeTriples = collectEdgeTriples(explainEvents);
const edgeSources = await resolveEdgeSources(edgeTriples);

// Collect all URIs that need labels: entities, edge terms, and the
// chunk/page/document URIs discovered during provenance resolution
const uris = collectUris(explainEvents);
for (const src of edgeSources.values()) {
    if (src.chunk) uris.add(src.chunk);
    if (src.page) uris.add(src.page);
    if (src.document) uris.add(src.document);
}
const labels = await resolveLabels([...uris]);

// Display helper: prefer the resolved label, fall back to the raw URI
const label = (uri) => labels.get(uri) || uri;
|
||||
|
||||
// ---------------------------------------------------------------------------
// Display: Entities retrieved during graph exploration
// ---------------------------------------------------------------------------
// One banner-and-list section is printed per Exploration event.

for (const { explainId, explainTriples } of explainEvents) {
    if (!explainTriples) continue;
    if (!isType(explainTriples, explainId, TG_EXPLORATION)) continue;
    // Resolve entity URIs to labels; show at most the first 10
    const entities = explainTriples
        .filter(t => t.s.i === explainId && t.p.i === TG_ENTITY)
        .map(t => label(t.o.i));
    const display = entities.slice(0, 10);
    console.log("=".repeat(80));
    console.log("Entities Retrieved");
    console.log("=".repeat(80));
    console.log(`  ${entities.length} entities: ${display.join(", ")}${entities.length > 10 ? ", ..." : ""}`);
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// Display: Knowledge graph edges with provenance
// ---------------------------------------------------------------------------

console.log("\n" + "=".repeat(80));
console.log("Knowledge Graph Edges");
console.log("=".repeat(80));
printFocusEdges(explainEvents, labels, edgeSources);

// ---------------------------------------------------------------------------
// Display: Source text for each chunk referenced by the edges
// ---------------------------------------------------------------------------

// De-duplicate chunk URIs so each source text is fetched only once
const uniqueChunks = new Set();
for (const src of edgeSources.values()) {
    if (src.chunk) uniqueChunks.add(src.chunk);
}

console.log(`\nFetching text for ${uniqueChunks.size} source chunks...`);
const chunkTexts = new Map();
await Promise.all([...uniqueChunks].map(async (chunkUri) => {
    try {
        // streamDocument returns base64-encoded content
        const text = await fetchChunkText(chunkUri);
        chunkTexts.set(chunkUri, text);
    } catch (e) {
        // Best-effort: chunks whose text can't be fetched are simply
        // skipped in the display below
    }
}));
|
||||
|
||||
console.log("\n" + "=".repeat(80));
console.log("Sources");
console.log("=".repeat(80));

// Print each unique source chunk with its document/page context and a
// wrapped snippet of the decoded text.
let sourceIndex = 0;
for (const chunkUri of uniqueChunks) {
    sourceIndex++;
    const chunkLabel = labels.get(chunkUri) || chunkUri;

    // Find the page and document labels for this chunk: scan the resolved
    // edge sources for the first entry referencing it
    let pageLabel, docLabel;
    for (const src of edgeSources.values()) {
        if (src.chunk === chunkUri) {
            if (src.page) pageLabel = labels.get(src.page) || src.page;
            if (src.document) docLabel = labels.get(src.document) || src.document;
            break;
        }
    }

    console.log(`\n  [${sourceIndex}] ${docLabel || "?"} / ${pageLabel || "?"} / ${chunkLabel}`);
    console.log("  " + "-".repeat(70));

    // Decode the base64 content and display a wrapped snippet
    const b64 = chunkTexts.get(chunkUri);
    if (b64) {
        const text = Buffer.from(b64, "base64").toString("utf-8");
        console.log(wrapText(text, 76, "    ", 6));
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// Clean up
// ---------------------------------------------------------------------------

console.log("\n" + "=".repeat(80));
console.log("Query complete");
console.log("=".repeat(80));

// Close the socket and exit explicitly — the open client would otherwise
// keep the event loop alive. NOTE(review): process.exit(0) assumes all
// console output has been flushed; confirm nothing async is still pending.
client.close();
process.exit(0);
|
||||
13
dev-tools/explainable-ai/package.json
Normal file
13
dev-tools/explainable-ai/package.json
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"name": "explain-api-example",
|
||||
"version": "1.0.0",
|
||||
"description": "TrustGraph explainability API example",
|
||||
"main": "index.js",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"test": "echo \"Error: no test specified\" && exit 1"
|
||||
},
|
||||
"dependencies": {
|
||||
"@trustgraph/client": "^1.7.2"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,655 @@
|
|||
"""
|
||||
Integration tests for agent-orchestrator provenance chains.
|
||||
|
||||
Tests all three patterns by calling iterate() with mocked dependencies
|
||||
and verifying the explain events emitted via respond().
|
||||
|
||||
Provenance chains:
|
||||
React: session → iteration → (observation or final)
|
||||
Plan: session → plan → step-result(s) → synthesis
|
||||
Supervisor: session → decomposition → finding(s) → synthesis
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
# Agent provenance type constants
|
||||
from trustgraph.provenance.namespaces import (
|
||||
TG_AGENT_QUESTION,
|
||||
TG_ANALYSIS,
|
||||
TG_TOOL_USE,
|
||||
TG_OBSERVATION_TYPE,
|
||||
TG_CONCLUSION,
|
||||
TG_DECOMPOSITION,
|
||||
TG_FINDING,
|
||||
TG_PLAN_TYPE,
|
||||
TG_STEP_RESULT,
|
||||
TG_SYNTHESIS as TG_AGENT_SYNTHESIS,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
    """Return the first triple whose predicate IRI matches.

    When ``subject`` is given, the triple's subject IRI must match too.
    Returns None if nothing matches.
    """
    for triple in triples:
        if triple.p.iri != predicate:
            continue
        if subject is not None and triple.s.iri != subject:
            continue
        return triple
    return None
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
    """True when ``triples`` assert ``subject rdf:type rdf_type``."""
    for triple in triples:
        if triple.s.iri != subject:
            continue
        if triple.p.iri != RDF_TYPE:
            continue
        if triple.o.iri == rdf_type:
            return True
    return False
|
||||
|
||||
|
||||
def derived_from(triples, subject):
    """Follow one prov:wasDerivedFrom link from ``subject``.

    Returns the object IRI of the first matching derivation triple, or
    None when ``subject`` has no derivation in ``triples``.
    """
    t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    return t.o.iri if t else None
|
||||
|
||||
|
||||
def collect_explain_events(respond_mock):
    """Extract explain events from a respond mock's call history."""
    events = []
    for call in respond_mock.call_args_list:
        # First positional argument of each respond(...) invocation
        response = call[0][0]
        is_explain = (
            isinstance(response, AgentResponse)
            and response.chunk_type == "explain"
        )
        if not is_explain:
            continue
        events.append({
            "explain_id": response.explain_id,
            "explain_graph": response.explain_graph,
            "triples": response.explain_triples,
        })
    return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock processor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_mock_processor(tools=None):
    """Build a mock processor with the minimal interface patterns need."""
    proc = MagicMock()
    proc.max_iterations = 10
    proc.save_answer_content = AsyncMock()

    # provenance_session_uri must hand back a real URI string, not a mock
    proc.provenance_session_uri.side_effect = (
        lambda session_id: f"urn:trustgraph:agent:session:{session_id}"
    )

    # Agent with tools
    mock_agent = MagicMock()
    mock_agent.tools = tools or {}
    mock_agent.additional_context = ""
    proc.agent = mock_agent

    # Aggregator for supervisor
    proc.aggregator = MagicMock()

    return proc
|
||||
|
||||
|
||||
def make_mock_flow():
    """Build a mock flow that returns async mock producers.

    Calling the flow with a name lazily creates one AsyncMock producer
    per name and returns the same producer on repeat calls. The cache is
    exposed as ``flow._producers`` for assertions.
    """
    producers = {}

    def producer_for(name):
        # Memoise: one producer per queue name
        try:
            return producers[name]
        except KeyError:
            producer = AsyncMock()
            producers[name] = producer
            return producer

    flow = MagicMock(side_effect=producer_for)
    flow._producers = producers
    return flow
|
||||
|
||||
|
||||
def make_base_request(**kwargs):
    """Build a minimal AgentRequest; keyword arguments override defaults."""
    params = {
        "question": "What is quantum computing?",
        "state": "",
        "group": [],
        "history": [],
        "user": "testuser",
        "collection": "default",
        "streaming": False,
        "session_id": "test-session-123",
        "conversation_id": "",
        "pattern": "react",
        "task_type": "",
        "framing": "",
        "correlation_id": "",
        "parent_session_id": "",
        "subagent_goal": "",
        "expected_siblings": 0,
    }
    params.update(kwargs)
    return AgentRequest(**params)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# React pattern tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReactPatternProvenance:
|
||||
"""
|
||||
React pattern chain: session → iteration → final
|
||||
(single iteration ending in Final answer)
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_iteration_final_answer(self):
|
||||
"""
|
||||
A single react iteration that produces a Final answer should emit:
|
||||
session, iteration, final — in that order.
|
||||
"""
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.react.types import Action, Final
|
||||
|
||||
processor = make_mock_processor()
|
||||
pattern = ReactPattern(processor)
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = make_mock_flow()
|
||||
|
||||
request = make_base_request()
|
||||
|
||||
# Mock AgentManager.react to call on_action then return Final
|
||||
with patch(
|
||||
'trustgraph.agent.orchestrator.react_pattern.AgentManager'
|
||||
) as MockAM:
|
||||
mock_am = AsyncMock()
|
||||
MockAM.return_value = mock_am
|
||||
|
||||
final = Final(
|
||||
thought="I know the answer",
|
||||
final="Quantum computing uses qubits.",
|
||||
)
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
# Simulate the on_action callback before returning Final
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
thought="I know the answer",
|
||||
name="final",
|
||||
arguments={},
|
||||
observation="",
|
||||
))
|
||||
return final
|
||||
|
||||
mock_am.react.side_effect = mock_react
|
||||
|
||||
await pattern.iterate(request, respond, next_fn, flow)
|
||||
|
||||
events = collect_explain_events(respond)
|
||||
|
||||
# Should have 3 events: session, iteration, final
|
||||
assert len(events) == 3, (
|
||||
f"Expected 3 explain events (session, iteration, final), "
|
||||
f"got {len(events)}: {[e['explain_id'] for e in events]}"
|
||||
)
|
||||
|
||||
# Check types
|
||||
assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
|
||||
assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
|
||||
assert has_type(events[2]["triples"], events[2]["explain_id"], TG_CONCLUSION)
|
||||
|
||||
# Check derivation chain
|
||||
all_triples = []
|
||||
for e in events:
|
||||
all_triples.extend(e["triples"])
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
|
||||
# iteration derives from session
|
||||
assert derived_from(all_triples, uris[1]) == uris[0]
|
||||
# final derives from session (first iteration, no prior observation)
|
||||
assert derived_from(all_triples, uris[2]) == uris[0]
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_iteration_with_tool_call(self):
        """
        A react iteration that calls a tool (not Final) should emit:
        session, iteration, observation — then call next() for continuation.
        """
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action

        # Create a mock tool; only the attributes the pattern inspects
        # (name, description, arguments, groups, states, implementation)
        # are stubbed out.
        mock_tool = MagicMock()
        mock_tool.name = "knowledge-query"
        mock_tool.description = "Query the knowledge base"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="The answer is 42")
        # implementation is a factory: calling it yields the async invoker.
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = make_mock_processor(
            tools={"knowledge-query": mock_tool}
        )
        pattern = ReactPattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        request = make_base_request()

        # A non-final action: the pattern should record an observation and
        # schedule another iteration rather than finishing.
        action = Action(
            thought="I need to look this up",
            name="knowledge-query",
            arguments={"question": "What is quantum computing?"},
            observation="Quantum computing uses qubits.",
        )

        # Patch AgentManager so react() is fully under test control.
        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                # Fire the on_action callback (as the real manager would)
                # before returning the action itself.
                if on_action:
                    await on_action(action)
                return action

            mock_am.react.side_effect = mock_react

            # iterate() must run while the patch is active so the pattern
            # picks up the mocked AgentManager.
            await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 3 events: session, iteration, observation
        assert len(events) == 3, (
            f"Expected 3 explain events (session, iteration, observation), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
        assert has_type(events[2]["triples"], events[2]["explain_id"], TG_OBSERVATION_TYPE)

        # next() should have been called to continue the loop
        assert next_fn.called
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_all_triples_in_retrieval_graph(self):
        """All explain triples should be in urn:graph:retrieval."""
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action, Final

        processor = make_mock_processor()
        pattern = ReactPattern(processor)
        respond = AsyncMock()
        flow = make_mock_flow()

        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                # Minimal single-iteration run: one final action, then done.
                if on_action:
                    await on_action(Action(
                        thought="done", name="final",
                        arguments={}, observation="",
                    ))
                return Final(thought="done", final="answer")

            mock_am.react.side_effect = mock_react
            await pattern.iterate(
                make_base_request(), respond, AsyncMock(), flow,
            )

        # Every triple of every emitted explain event must target the
        # retrieval graph.
        for event in collect_explain_events(respond):
            for t in event["triples"]:
                assert t.g == GRAPH_RETRIEVAL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plan-then-execute pattern tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPlanPatternProvenance:
    """
    Plan pattern chain:
    Planning iteration: session → plan
    Execution iterations: step-result(s) → synthesis

    Each test drives PlanThenExecutePattern.iterate() with mocked prompt
    clients and inspects the explain events it emits via respond().
    """

    @pytest.mark.asyncio
    async def test_planning_iteration_emits_session_and_plan(self):
        """
        The first iteration (planning) should emit:
        session, plan — then call next() with the plan in history.
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        processor = make_mock_processor()
        pattern = PlanThenExecutePattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt client for plan creation: returns a two-step plan.
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = [
            {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
            {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
        ]

        # flow() is keyed by service name; only the prompt service needs
        # a real stub here.
        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()
        flow.side_effect = flow_factory

        request = make_base_request(pattern="plan")

        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 2 events: session, plan
        assert len(events) == 2, (
            f"Expected 2 explain events (session, plan), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_PLAN_TYPE)

        # Plan should derive from session
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])
        assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"]

        # next() should have been called with plan in history
        assert next_fn.called

    @pytest.mark.asyncio
    async def test_execution_iteration_emits_step_result(self):
        """
        An execution iteration should emit a step-result event.
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        # Create a mock tool for the step to invoke.
        mock_tool = MagicMock()
        mock_tool.name = "knowledge-query"
        mock_tool.description = "Query KB"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="Found the answer")
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = make_mock_processor(
            tools={"knowledge-query": mock_tool}
        )
        pattern = PlanThenExecutePattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt for step execution: selects the tool and arguments.
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = {
            "tool": "knowledge-query",
            "arguments": {"question": "quantum computing"},
        }

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()
        flow.side_effect = flow_factory

        # Request with plan already in history (second iteration), with
        # one pending step — so iterate() goes down the execution path.
        plan_step = AgentStep(
            thought="Created plan",
            action="plan",
            arguments={},
            observation="[]",
            step_type="plan",
            plan=[
                PlanStep(goal="Find info", tool_hint="knowledge-query",
                         depends_on=[], status="pending", result=""),
            ],
        )
        request = make_base_request(
            pattern="plan",
            history=[plan_step],
        )

        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have step-result (no session on iteration > 1)
        step_events = [
            e for e in events
            if has_type(e["triples"], e["explain_id"], TG_STEP_RESULT)
        ]
        assert len(step_events) == 1, (
            f"Expected 1 step-result event, got {len(step_events)}"
        )

    @pytest.mark.asyncio
    async def test_synthesis_after_all_steps_complete(self):
        """
        When all plan steps are completed, synthesis should be emitted.
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        processor = make_mock_processor()
        pattern = PlanThenExecutePattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt for synthesis of the final answer.
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = "The synthesised answer."

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()
        flow.side_effect = flow_factory

        # Request with all steps completed — iterate() should skip
        # execution and synthesise.
        exec_step = AgentStep(
            thought="Executing step",
            action="knowledge-query",
            arguments={},
            observation="Result",
            step_type="execute",
            plan=[
                PlanStep(goal="Find info", tool_hint="knowledge-query",
                         depends_on=[], status="completed", result="Found it"),
            ],
        )
        request = make_base_request(
            pattern="plan",
            history=[exec_step],
        )

        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have synthesis event
        synth_events = [
            e for e in events
            if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS)
        ]
        assert len(synth_events) == 1, (
            f"Expected 1 synthesis event, got {len(synth_events)}"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supervisor pattern tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSupervisorPatternProvenance:
    """
    Supervisor pattern chain:
    Decompose: session → decomposition
    (Fan-out to subagents happens externally)
    Synthesise: synthesis (derives from findings)

    Tests drive SupervisorPattern.iterate() with a stubbed prompt client
    and verify the explain events and the subagent fan-out.
    """

    @pytest.mark.asyncio
    async def test_decompose_emits_session_and_decomposition(self):
        """
        The decompose phase should emit: session, decomposition.
        """
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        processor = make_mock_processor()
        pattern = SupervisorPattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt for decomposition: question is split into two goals.
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = [
            "What is quantum computing?",
            "What are qubits?",
        ]

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()
        flow.side_effect = flow_factory

        request = make_base_request(pattern="supervisor")

        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 2 events: session, decomposition
        assert len(events) == 2, (
            f"Expected 2 explain events (session, decomposition), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_DECOMPOSITION)

        # Decomposition derives from session
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])
        assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"]

    @pytest.mark.asyncio
    async def test_synthesis_emits_after_subagent_results(self):
        """
        When subagent results arrive, synthesis should be emitted.
        """
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        processor = make_mock_processor()
        pattern = SupervisorPattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        # Mock prompt for synthesis of the combined answer.
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = "The combined answer."

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()
        flow.side_effect = flow_factory

        # Request with subagent results in history — simulates the state
        # after fan-out has completed.
        synth_step = AgentStep(
            thought="",
            action="synthesise",
            arguments={},
            observation="",
            step_type="synthesise",
            subagent_results={
                "What is quantum computing?": "It uses qubits",
                "What are qubits?": "Quantum bits",
            },
        )
        request = make_base_request(
            pattern="supervisor",
            history=[synth_step],
        )

        await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have synthesis event (no session on iteration > 1)
        synth_events = [
            e for e in events
            if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS)
        ]
        assert len(synth_events) == 1

    @pytest.mark.asyncio
    async def test_decompose_fans_out_subagents(self):
        """The decompose phase should call next() for each subagent goal."""
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        processor = make_mock_processor()
        pattern = SupervisorPattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()

        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]

        def flow_factory(name):
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()
        flow.side_effect = flow_factory

        request = make_base_request(pattern="supervisor")

        await pattern.iterate(request, respond, next_fn, flow)

        # 3 subagent requests fanned out — one next() call per goal.
        assert next_fn.call_count == 3
|
||||
295
tests/unit/test_provenance/test_graph_rag_chain.py
Normal file
295
tests/unit/test_provenance/test_graph_rag_chain.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
"""
|
||||
Structural test for the graph-rag provenance chain.
|
||||
|
||||
Verifies that a complete graph-rag query produces the expected
|
||||
provenance chain:
|
||||
|
||||
question → grounding → exploration → focus → synthesis
|
||||
|
||||
Each step must:
|
||||
- Have the correct rdf:type
|
||||
- Link to its predecessor via prov:wasDerivedFrom
|
||||
- Carry expected domain-specific data
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.provenance.triples import (
|
||||
question_triples,
|
||||
grounding_triples,
|
||||
exploration_triples,
|
||||
focus_triples,
|
||||
synthesis_triples,
|
||||
)
|
||||
from trustgraph.provenance.uris import (
|
||||
question_uri,
|
||||
grounding_uri,
|
||||
exploration_uri,
|
||||
focus_uri,
|
||||
synthesis_uri,
|
||||
)
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_GRAPH_RAG_QUESTION, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_DOCUMENT,
|
||||
PROV_STARTED_AT_TIME,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SESSION_ID = "test-session-1234"
|
||||
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
    """Return the first triple whose predicate IRI equals ``predicate``.

    When ``subject`` is supplied, the triple's subject IRI must match it
    as well. Returns ``None`` when no triple qualifies.
    """
    return next(
        (
            t for t in triples
            if t.p.iri == predicate
            and (subject is None or t.s.iri == subject)
        ),
        None,
    )
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
    """Return every triple whose predicate IRI equals ``predicate``.

    When ``subject`` is supplied, only triples whose subject IRI also
    matches are kept. Input order is preserved.
    """
    matches = []
    for triple in triples:
        if triple.p.iri != predicate:
            continue
        if subject is not None and triple.s.iri != subject:
            continue
        matches.append(triple)
    return matches
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
    """Return True when ``subject`` carries an rdf:type triple of ``rdf_type``."""
    for triple in triples:
        if (triple.s.iri == subject
                and triple.p.iri == RDF_TYPE
                and triple.o.iri == rdf_type):
            return True
    return False
|
||||
|
||||
|
||||
def derived_from(triples, subject):
    """Return the object IRI of ``subject``'s prov:wasDerivedFrom triple.

    Returns ``None`` when the subject has no derivation link.
    """
    link = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    if not link:
        return None
    return link.o.iri
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build the full chain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def chain():
    """Build all provenance triples for a complete graph-rag query.

    Returns a dict with:
      - "uris":    stage name → stage URI
      - "triples": stage name → that stage's triple list
      - "all":     the concatenation of every stage's triples
    """
    # One URI per stage, all keyed off the same session id.
    q_uri = question_uri(SESSION_ID)
    gnd_uri = grounding_uri(SESSION_ID)
    exp_uri = exploration_uri(SESSION_ID)
    foc_uri = focus_uri(SESSION_ID)
    syn_uri = synthesis_uri(SESSION_ID)

    # Each builder links a stage to its predecessor via the parent URI,
    # forming question → grounding → exploration → focus → synthesis.
    q = question_triples(q_uri, "What is quantum computing?", "2026-01-01T00:00:00Z")
    gnd = grounding_triples(gnd_uri, q_uri, ["quantum", "computing"])
    exp = exploration_triples(
        exp_uri, gnd_uri, edge_count=42,
        entities=["urn:entity:1", "urn:entity:2"],
    )
    foc = focus_triples(
        foc_uri, exp_uri,
        # Two selected edges, each a (s, p, o) tuple plus the LLM's
        # reasoning for choosing it.
        selected_edges_with_reasoning=[
            {
                "edge": (
                    "http://example.com/QuantumComputing",
                    "http://schema.org/relatedTo",
                    "http://example.com/Physics",
                ),
                "reasoning": "Directly relevant to the query",
            },
            {
                "edge": (
                    "http://example.com/QuantumComputing",
                    "http://schema.org/name",
                    "Quantum Computing",
                ),
                "reasoning": "Provides the entity label",
            },
        ],
        session_id=SESSION_ID,
    )
    syn = synthesis_triples(syn_uri, foc_uri, document_id="urn:doc:answer-1")

    return {
        "uris": {
            "question": q_uri,
            "grounding": gnd_uri,
            "exploration": exp_uri,
            "focus": foc_uri,
            "synthesis": syn_uri,
        },
        "triples": {
            "question": q,
            "grounding": gnd,
            "exploration": exp,
            "focus": foc,
            "synthesis": syn,
        },
        "all": q + gnd + exp + foc + syn,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chain structure tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagProvenanceChain:
    """Verify the full question → grounding → exploration → focus → synthesis chain."""

    def test_chain_has_five_stages(self, chain):
        """Each stage should produce at least some triples."""
        for stage in ["question", "grounding", "exploration", "focus", "synthesis"]:
            assert len(chain["triples"][stage]) > 0, f"{stage} produced no triples"

    def test_derivation_chain(self, chain):
        """
        The wasDerivedFrom links must form:
        grounding → question, exploration → grounding,
        focus → exploration, synthesis → focus.
        """
        uris = chain["uris"]
        all_triples = chain["all"]

        assert derived_from(all_triples, uris["grounding"]) == uris["question"]
        assert derived_from(all_triples, uris["exploration"]) == uris["grounding"]
        assert derived_from(all_triples, uris["focus"]) == uris["exploration"]
        assert derived_from(all_triples, uris["synthesis"]) == uris["focus"]

    def test_question_has_no_parent(self, chain):
        """The root question should not derive from anything (no parent_uri)."""
        uris = chain["uris"]
        all_triples = chain["all"]
        assert derived_from(all_triples, uris["question"]) is None

    def test_question_with_parent(self):
        """When a parent_uri is given, question should derive from it."""
        # Built outside the fixture: the fixture's question has no parent.
        q_uri = question_uri("child-session")
        parent = "urn:trustgraph:agent:iteration:parent"
        q = question_triples(q_uri, "sub-query", "2026-01-01T00:00:00Z",
                             parent_uri=parent)
        assert derived_from(q, q_uri) == parent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type annotation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagProvenanceTypes:
    """Each stage must have the correct rdf:type annotations."""

    def test_question_types(self, chain):
        subject = chain["uris"]["question"]
        stage = chain["triples"]["question"]
        assert has_type(stage, subject, PROV_ENTITY)
        assert has_type(stage, subject, TG_GRAPH_RAG_QUESTION)

    def test_grounding_types(self, chain):
        subject = chain["uris"]["grounding"]
        stage = chain["triples"]["grounding"]
        assert has_type(stage, subject, PROV_ENTITY)
        assert has_type(stage, subject, TG_GROUNDING)

    def test_exploration_types(self, chain):
        subject = chain["uris"]["exploration"]
        stage = chain["triples"]["exploration"]
        assert has_type(stage, subject, PROV_ENTITY)
        assert has_type(stage, subject, TG_EXPLORATION)

    def test_focus_types(self, chain):
        subject = chain["uris"]["focus"]
        stage = chain["triples"]["focus"]
        assert has_type(stage, subject, PROV_ENTITY)
        assert has_type(stage, subject, TG_FOCUS)

    def test_synthesis_types(self, chain):
        # Synthesis is triply typed: generic entity, stage, and answer.
        subject = chain["uris"]["synthesis"]
        stage = chain["triples"]["synthesis"]
        assert has_type(stage, subject, PROV_ENTITY)
        assert has_type(stage, subject, TG_SYNTHESIS)
        assert has_type(stage, subject, TG_ANSWER_TYPE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domain-specific content tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagProvenanceContent:
    """Each stage should carry the expected domain data."""

    def test_question_has_query_text(self, chain):
        q_uri = chain["uris"]["question"]
        triple = find_triple(chain["triples"]["question"], TG_QUERY, q_uri)
        assert triple is not None
        assert triple.o.value == "What is quantum computing?"

    def test_question_has_timestamp(self, chain):
        q_uri = chain["uris"]["question"]
        triple = find_triple(
            chain["triples"]["question"], PROV_STARTED_AT_TIME, q_uri,
        )
        assert triple is not None
        assert triple.o.value == "2026-01-01T00:00:00Z"

    def test_grounding_has_concepts(self, chain):
        g_uri = chain["uris"]["grounding"]
        matches = find_triples(chain["triples"]["grounding"], TG_CONCEPT, g_uri)
        concept_values = {m.o.value for m in matches}
        assert concept_values == {"quantum", "computing"}

    def test_exploration_has_edge_count(self, chain):
        e_uri = chain["uris"]["exploration"]
        triple = find_triple(chain["triples"]["exploration"], TG_EDGE_COUNT, e_uri)
        assert triple is not None
        assert triple.o.value == "42"

    def test_exploration_has_entities(self, chain):
        e_uri = chain["uris"]["exploration"]
        matches = find_triples(chain["triples"]["exploration"], TG_ENTITY, e_uri)
        entity_iris = {m.o.iri for m in matches}
        assert entity_iris == {"urn:entity:1", "urn:entity:2"}

    def test_focus_has_selected_edges(self, chain):
        f_uri = chain["uris"]["focus"]
        selected = find_triples(chain["triples"]["focus"], TG_SELECTED_EDGE, f_uri)
        assert len(selected) == 2

    def test_focus_edges_have_quoted_triples(self, chain):
        """Each edge selection entity should have a tg:edge with a quoted triple."""
        edge_triples = find_triples(chain["triples"]["focus"], TG_EDGE)
        assert len(edge_triples) == 2

        # Each should have a quoted triple as the object
        for triple in edge_triples:
            assert triple.o.triple is not None, "tg:edge object should be a quoted triple"

    def test_focus_edges_have_reasoning(self, chain):
        """Each edge selection entity should have tg:reasoning."""
        notes = find_triples(chain["triples"]["focus"], TG_REASONING)
        assert len(notes) == 2
        reasoning_texts = {n.o.value for n in notes}
        assert "Directly relevant to the query" in reasoning_texts
        assert "Provides the entity label" in reasoning_texts

    def test_synthesis_has_document_ref(self, chain):
        s_uri = chain["uris"]["synthesis"]
        triple = find_triple(chain["triples"]["synthesis"], TG_DOCUMENT, s_uri)
        assert triple is not None
        assert triple.o.iri == "urn:doc:answer-1"

    def test_synthesis_has_labels(self, chain):
        s_uri = chain["uris"]["synthesis"]
        triple = find_triple(chain["triples"]["synthesis"], RDFS_LABEL, s_uri)
        assert triple is not None
        assert triple.o.value == "Synthesis"
|
||||
|
|
@ -0,0 +1,380 @@
|
|||
"""
|
||||
Integration test: run a full DocumentRag.query() with mocked subsidiary
|
||||
clients and verify the explain_callback receives the complete provenance
|
||||
chain in the correct order with correct structure.
|
||||
|
||||
Document-RAG provenance chain (4 stages):
|
||||
question → grounding → exploration → synthesis
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_SYNTHESIS, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT,
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
    """Return the first triple matching ``predicate`` (and ``subject`` if given)."""
    for candidate in triples:
        if candidate.p.iri != predicate:
            continue
        if subject is not None and candidate.s.iri != subject:
            continue
        return candidate
    return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
    """Return all triples matching ``predicate`` (and ``subject`` if given)."""
    def matches(triple):
        # Predicate must match; subject is only checked when requested.
        if triple.p.iri != predicate:
            return False
        return subject is None or triple.s.iri == subject

    return [t for t in triples if matches(t)]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
    """Return True when ``subject`` is rdf:typed as ``rdf_type`` in ``triples``."""
    for triple in triples:
        if triple.s.iri != subject:
            continue
        if triple.p.iri == RDF_TYPE and triple.o.iri == rdf_type:
            return True
    return False
|
||||
|
||||
|
||||
def derived_from(triples, subject):
    """Return the prov:wasDerivedFrom target IRI for ``subject``, or ``None``."""
    link = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    if not link:
        return None
    return link.o.iri
|
||||
|
||||
|
||||
@dataclass
class ChunkMatch:
    """Mimics the result from doc_embeddings_client.query().

    Only ``chunk_id`` is consumed by the code under test, so the stand-in
    carries just that one field.
    """
    # Chunk identifier, e.g. "urn:chunk:policy-doc-1:chunk-0".
    chunk_id: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
|
||||
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
|
||||
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
|
||||
CHUNK_B_CONTENT = "Refunds are processed to the original payment method."
|
||||
|
||||
|
||||
def build_mock_clients():
    """
    Build mock clients for a document-rag query.

    Client call sequence during query():
      1. prompt_client.prompt("extract-concepts", ...) -> concepts
      2. embeddings_client.embed(concepts) -> vectors
      3. doc_embeddings_client.query(vector, ...) -> chunk matches
      4. fetch_chunk(chunk_id, user) -> chunk content
      5. prompt_client.document_prompt(query, documents) -> answer

    Returns the four mocks in the positional order DocumentRag expects:
    (prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk).
    """
    prompt_client = AsyncMock()
    embeddings_client = AsyncMock()
    doc_embeddings_client = AsyncMock()
    fetch_chunk = AsyncMock()

    # 1. Concept extraction — any other template id yields "".
    async def mock_prompt(template_id, variables=None, **kwargs):
        if template_id == "extract-concepts":
            return "return policy\nrefund"
        return ""

    prompt_client.prompt.side_effect = mock_prompt

    # 2. Embedding vectors: one vector per extracted concept.
    embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]

    # 3. Chunk matching: always returns the same two chunk ids.
    doc_embeddings_client.query.return_value = [
        ChunkMatch(chunk_id=CHUNK_A),
        ChunkMatch(chunk_id=CHUNK_B),
    ]

    # 4. Chunk content lookup — raises KeyError on an unknown chunk id,
    # which would surface as a test failure.
    async def mock_fetch(chunk_id, user):
        return {
            CHUNK_A: CHUNK_A_CONTENT,
            CHUNK_B: CHUNK_B_CONTENT,
        }[chunk_id]

    fetch_chunk.side_effect = mock_fetch

    # 5. Synthesis of the final answer from the fetched documents.
    prompt_client.document_prompt.return_value = (
        "Items can be returned within 30 days for a full refund."
    )

    return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentRagQueryProvenance:
|
||||
"""
|
||||
Run a real DocumentRag.query() and verify the provenance chain emitted
|
||||
via explain_callback.
|
||||
"""
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_explain_callback_receives_four_events(self):
        """query() should emit exactly 4 explain events."""
        clients = build_mock_clients()
        rag = DocumentRag(*clients)

        # Capture every (triples, explain_id) pair the query emits.
        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        assert len(events) == 4, (
            f"Expected 4 explain events (question, grounding, exploration, "
            f"synthesis), got {len(events)}"
        )
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_events_have_correct_types_in_order(self):
        """
        Events should arrive as:
        question, grounding, exploration, synthesis.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        # Emission order must mirror the document-rag pipeline stages.
        expected_types = [
            TG_DOC_RAG_QUESTION,
            TG_GROUNDING,
            TG_EXPLORATION,
            TG_SYNTHESIS,
        ]

        for i, expected_type in enumerate(expected_types):
            uri = events[i]["explain_id"]
            triples = events[i]["triples"]
            assert has_type(triples, uri, expected_type), (
                f"Event {i} (uri={uri}) should have type {expected_type}"
            )
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_derivation_chain_links_correctly(self):
|
||||
"""
|
||||
Each event's URI should link to the previous via wasDerivedFrom:
|
||||
question → (none)
|
||||
grounding → question
|
||||
exploration → grounding
|
||||
synthesis → exploration
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
all_triples = []
|
||||
for e in events:
|
||||
all_triples.extend(e["triples"])
|
||||
|
||||
# question has no parent
|
||||
assert derived_from(all_triples, uris[0]) is None
|
||||
|
||||
# grounding → question
|
||||
assert derived_from(all_triples, uris[1]) == uris[0]
|
||||
|
||||
# exploration → grounding
|
||||
assert derived_from(all_triples, uris[2]) == uris[1]
|
||||
|
||||
# synthesis → exploration
|
||||
assert derived_from(all_triples, uris[3]) == uris[2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_question_carries_query_text(self):
|
||||
"""The question event should contain the original query string."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
q_triples = events[0]["triples"]
|
||||
t = find_triple(q_triples, TG_QUERY, q_uri)
|
||||
assert t is not None
|
||||
assert t.o.value == "What is the return policy?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grounding_carries_concepts(self):
|
||||
"""The grounding event should list extracted concepts."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
gnd_uri = events[1]["explain_id"]
|
||||
gnd_triples = events[1]["triples"]
|
||||
concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
|
||||
concept_values = {t.o.value for t in concepts}
|
||||
assert "return policy" in concept_values
|
||||
assert "refund" in concept_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exploration_has_chunk_count(self):
|
||||
"""The exploration event should report the number of chunks retrieved."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
exp_triples = events[2]["triples"]
|
||||
t = find_triple(exp_triples, TG_CHUNK_COUNT, exp_uri)
|
||||
assert t is not None
|
||||
assert int(t.o.value) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exploration_has_selected_chunks(self):
|
||||
"""The exploration event should list the chunk IDs that were fetched."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
exp_triples = events[2]["triples"]
|
||||
chunks = find_triples(exp_triples, TG_SELECTED_CHUNK, exp_uri)
|
||||
chunk_iris = {t.o.iri for t in chunks}
|
||||
assert CHUNK_A in chunk_iris
|
||||
assert CHUNK_B in chunk_iris
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_is_answer_type(self):
|
||||
"""The synthesis event should have tg:Synthesis and tg:Answer types."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
syn_uri = events[3]["explain_id"]
|
||||
syn_triples = events[3]["triples"]
|
||||
assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
|
||||
assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_answer_text(self):
|
||||
"""query() should return the synthesised answer."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=AsyncMock(),
|
||||
)
|
||||
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_still_works(self):
|
||||
"""query() without explain_callback should return answer normally."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(query="What is the return policy?")
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
"""All emitted triples should be in the urn:graph:retrieval graph."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
for event in events:
|
||||
for t in event["triples"]:
|
||||
assert t.g == "urn:graph:retrieval", (
|
||||
f"Triple {t.s.iri} {t.p.iri} should be in "
|
||||
f"urn:graph:retrieval, got {t.g}"
|
||||
)
|
||||
|
|
@ -465,12 +465,15 @@ class TestQuery:
|
|||
return_value=(["entity1", "entity2"], ["concept1"])
|
||||
)
|
||||
|
||||
query.follow_edges_batch = AsyncMock(return_value={
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
})
|
||||
query.follow_edges_batch = AsyncMock(return_value=(
|
||||
{
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
},
|
||||
{}
|
||||
))
|
||||
|
||||
subgraph, entities, concepts = await query.get_subgraph("test query")
|
||||
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
|
||||
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
|
@ -503,7 +506,7 @@ class TestQuery:
|
|||
test_entities = ["entity1", "entity3"]
|
||||
test_concepts = ["concept1"]
|
||||
query.get_subgraph = AsyncMock(
|
||||
return_value=(test_subgraph, test_entities, test_concepts)
|
||||
return_value=(test_subgraph, {}, test_entities, test_concepts)
|
||||
)
|
||||
|
||||
async def mock_maybe_label(entity):
|
||||
|
|
|
|||
358
tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py
Normal file
358
tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py
Normal file
|
|
@ -0,0 +1,358 @@
|
|||
"""
|
||||
Tests that explain_triples are forwarded correctly through the graph-rag
|
||||
service and client layers.
|
||||
|
||||
Covers:
|
||||
- Service: explain messages include triples from the provenance callback
|
||||
- Client: explain_callback receives explain_triples from the response
|
||||
- End-to-end: triples survive the full service → client → callback chain
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.schema import (
|
||||
GraphRagQuery, GraphRagResponse,
|
||||
Triple, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base.graph_rag_client import GraphRagClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_triple(s_iri, p_iri, o_value, o_type=IRI):
|
||||
"""Create a Triple with IRI subject/predicate and typed object."""
|
||||
o = (
|
||||
Term(type=IRI, iri=o_value) if o_type == IRI
|
||||
else Term(type=LITERAL, value=o_value)
|
||||
)
|
||||
return Triple(
|
||||
s=Term(type=IRI, iri=s_iri),
|
||||
p=Term(type=IRI, iri=p_iri),
|
||||
o=o,
|
||||
)
|
||||
|
||||
|
||||
def sample_focus_triples():
|
||||
"""Focus-style triples with a quoted triple (edge selection)."""
|
||||
return [
|
||||
make_triple(
|
||||
"urn:trustgraph:focus:abc",
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
"https://trustgraph.ai/ns/Focus",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:focus:abc",
|
||||
"http://www.w3.org/ns/prov#wasDerivedFrom",
|
||||
"urn:trustgraph:exploration:abc",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:focus:abc",
|
||||
"https://trustgraph.ai/ns/selectedEdge",
|
||||
"urn:trustgraph:edge-sel:abc:0",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def sample_question_triples():
|
||||
"""Question-style triples."""
|
||||
return [
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc",
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
"https://trustgraph.ai/ns/GraphRagQuestion",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc",
|
||||
"https://trustgraph.ai/ns/query",
|
||||
"What is quantum computing?",
|
||||
o_type=LITERAL,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service-level: explain messages carry triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagServiceExplainTriples:
|
||||
"""Test that the graph-rag service includes explain_triples in messages."""
|
||||
|
||||
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_messages_include_triples(self, mock_graph_rag_class):
|
||||
"""
|
||||
When the provenance callback is invoked with triples, the service
|
||||
should include them in the explain response message.
|
||||
"""
|
||||
from trustgraph.retrieval.graph_rag.rag import Processor
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id="test-processor",
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
max_subgraph_size=150,
|
||||
max_path_length=2,
|
||||
)
|
||||
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_graph_rag_class.return_value = mock_rag_instance
|
||||
|
||||
question_triples = sample_question_triples()
|
||||
focus_triples = sample_focus_triples()
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
explain_callback = kwargs.get('explain_callback')
|
||||
if explain_callback:
|
||||
await explain_callback(
|
||||
question_triples, "urn:trustgraph:question:abc"
|
||||
)
|
||||
await explain_callback(
|
||||
focus_triples, "urn:trustgraph:focus:abc"
|
||||
)
|
||||
return "The answer."
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="What is quantum computing?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_provenance = AsyncMock()
|
||||
|
||||
def flow_router(name):
|
||||
if name == "response":
|
||||
return mock_response
|
||||
if name == "explainability":
|
||||
return mock_provenance
|
||||
return AsyncMock()
|
||||
|
||||
flow.side_effect = flow_router
|
||||
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Find the explain messages
|
||||
explain_msgs = [
|
||||
call[0][0]
|
||||
for call in mock_response.send.call_args_list
|
||||
if call[0][0].message_type == "explain"
|
||||
]
|
||||
|
||||
assert len(explain_msgs) == 2
|
||||
|
||||
# First explain message should carry question triples
|
||||
assert explain_msgs[0].explain_id == "urn:trustgraph:question:abc"
|
||||
assert explain_msgs[0].explain_triples == question_triples
|
||||
|
||||
# Second explain message should carry focus triples
|
||||
assert explain_msgs[1].explain_id == "urn:trustgraph:focus:abc"
|
||||
assert explain_msgs[1].explain_triples == focus_triples
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Client-level: explain_callback receives triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagClientExplainForwarding:
|
||||
"""Test that GraphRagClient.rag() forwards explain_triples to callback."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_triples(self):
|
||||
"""
|
||||
The explain_callback should receive (explain_id, explain_graph,
|
||||
explain_triples) — not just (explain_id, explain_graph).
|
||||
"""
|
||||
focus_triples = sample_focus_triples()
|
||||
|
||||
# Simulate the response sequence the client would receive
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:focus:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=focus_triples,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="The answer.",
|
||||
end_of_stream=True,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="",
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
# Capture what the explain_callback receives
|
||||
received_calls = []
|
||||
|
||||
async def explain_callback(explain_id, explain_graph, explain_triples):
|
||||
received_calls.append({
|
||||
"explain_id": explain_id,
|
||||
"explain_graph": explain_graph,
|
||||
"explain_triples": explain_triples,
|
||||
})
|
||||
|
||||
# Patch self.request to feed responses to the recipient
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
result = await client.rag(
|
||||
query="test",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
assert result == "The answer."
|
||||
assert len(received_calls) == 1
|
||||
assert received_calls[0]["explain_id"] == "urn:trustgraph:focus:abc"
|
||||
assert received_calls[0]["explain_graph"] == "urn:graph:retrieval"
|
||||
assert received_calls[0]["explain_triples"] == focus_triples
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_empty_triples(self):
|
||||
"""
|
||||
When an explain event has no triples, the callback should still
|
||||
receive an empty list (not None or missing).
|
||||
"""
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=[],
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Answer.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
received_calls = []
|
||||
|
||||
async def explain_callback(explain_id, explain_graph, explain_triples):
|
||||
received_calls.append(explain_triples)
|
||||
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
await client.rag(query="test", explain_callback=explain_callback)
|
||||
|
||||
assert len(received_calls) == 1
|
||||
assert received_calls[0] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_explain_events_all_forward_triples(self):
|
||||
"""
|
||||
Each explain event in a session should forward its own triples.
|
||||
"""
|
||||
q_triples = sample_question_triples()
|
||||
f_triples = sample_focus_triples()
|
||||
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=q_triples,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:focus:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=f_triples,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Answer.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
received_calls = []
|
||||
|
||||
async def explain_callback(explain_id, explain_graph, explain_triples):
|
||||
received_calls.append({
|
||||
"explain_id": explain_id,
|
||||
"explain_triples": explain_triples,
|
||||
})
|
||||
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
await client.rag(query="test", explain_callback=explain_callback)
|
||||
|
||||
assert len(received_calls) == 2
|
||||
assert received_calls[0]["explain_id"] == "urn:trustgraph:question:abc"
|
||||
assert received_calls[0]["explain_triples"] == q_triples
|
||||
assert received_calls[1]["explain_id"] == "urn:trustgraph:focus:abc"
|
||||
assert received_calls[1]["explain_triples"] == f_triples
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_does_not_error(self):
|
||||
"""
|
||||
When no explain_callback is provided, explain events should be
|
||||
silently skipped without errors.
|
||||
"""
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_question_triples(),
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Answer.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
result = await client.rag(query="test")
|
||||
assert result == "Answer."
|
||||
|
|
@ -0,0 +1,482 @@
|
|||
"""
|
||||
Integration test: run a full GraphRag.query() with mocked subsidiary clients
|
||||
and verify the explain_callback receives the complete provenance chain
|
||||
in the correct order with correct structure.
|
||||
|
||||
This tests the real query() method end-to-end, not just the triple builders.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
|
||||
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate
|
||||
and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
return any(
|
||||
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
|
||||
for t in triples
|
||||
)
|
||||
|
||||
|
||||
def derived_from(triples, subject):
|
||||
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
|
||||
return t.o.iri if t else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingMatch:
|
||||
"""Mimics the result from graph_embeddings_client.query()."""
|
||||
entity: Term
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# A tiny knowledge graph: 2 entities, 3 edges
|
||||
ENTITY_A = "http://example.com/QuantumComputing"
|
||||
ENTITY_B = "http://example.com/Physics"
|
||||
EDGE_1 = (ENTITY_A, "http://schema.org/relatedTo", ENTITY_B)
|
||||
EDGE_2 = (ENTITY_A, "http://schema.org/name", "Quantum Computing")
|
||||
EDGE_3 = (ENTITY_B, "http://schema.org/name", "Physics")
|
||||
|
||||
|
||||
def make_schema_triple(s, p, o):
|
||||
"""Create a SchemaTriple from string values."""
|
||||
return SchemaTriple(
|
||||
s=Term(type=IRI, iri=s),
|
||||
p=Term(type=IRI, iri=p),
|
||||
o=Term(type=IRI, iri=o) if o.startswith("http") else Term(type=LITERAL, value=o),
|
||||
)
|
||||
|
||||
|
||||
def build_mock_clients():
|
||||
"""
|
||||
Build mock clients that simulate a small knowledge graph query.
|
||||
|
||||
Client call sequence during query():
|
||||
1. prompt_client.prompt("extract-concepts", ...) -> concepts
|
||||
2. embeddings_client.embed(concepts) -> vectors
|
||||
3. graph_embeddings_client.query(vector, ...) -> entity matches
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
|
||||
5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
|
||||
6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
|
||||
7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
|
||||
8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
9. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
|
||||
# 1. Concept extraction
|
||||
prompt_responses = {}
|
||||
prompt_responses["extract-concepts"] = "quantum computing\nphysics"
|
||||
|
||||
# 2. Embedding vectors (simple fake vectors)
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# 3. Entity lookup - return our two entities
|
||||
graph_embeddings_client.query.return_value = [
|
||||
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_A)),
|
||||
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
|
||||
]
|
||||
|
||||
# 4. Triple queries (follow_edges_batch) - return our edges
|
||||
kg_triples = [
|
||||
make_schema_triple(*EDGE_1),
|
||||
make_schema_triple(*EDGE_2),
|
||||
make_schema_triple(*EDGE_3),
|
||||
]
|
||||
triples_client.query_stream.return_value = kg_triples
|
||||
|
||||
# 5. Label resolution - return entity as its own label (simplify)
|
||||
async def mock_label_query(s=None, p=None, o=None, limit=1,
|
||||
user=None, collection=None, g=None):
|
||||
return [] # No labels found, will fall back to URI
|
||||
triples_client.query.side_effect = mock_label_query
|
||||
|
||||
# 6+7. Edge scoring and reasoning: dynamically score/reason about
|
||||
# whatever edges the query method sends us, since edge IDs are computed
|
||||
# from str(Term) representations which include the full dataclass repr.
|
||||
synthesis_answer = "Quantum computing applies physics principles to computation."
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return prompt_responses["extract-concepts"]
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
elif template_id == "kg-synthesis":
|
||||
return synthesis_answer
|
||||
return ""
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagQueryProvenance:
|
||||
"""
|
||||
Run a real GraphRag.query() and verify the provenance chain emitted
|
||||
via explain_callback.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_five_events(self):
|
||||
"""query() should emit exactly 5 explain events."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0, # skip semantic pre-filter for simplicity
|
||||
)
|
||||
|
||||
assert len(events) == 5, (
|
||||
f"Expected 5 explain events (question, grounding, exploration, "
|
||||
f"focus, synthesis), got {len(events)}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_events_have_correct_types_in_order(self):
|
||||
"""
|
||||
Events should arrive as:
|
||||
question, grounding, exploration, focus, synthesis.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
expected_types = [
|
||||
TG_GRAPH_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_FOCUS,
|
||||
TG_SYNTHESIS,
|
||||
]
|
||||
|
||||
for i, expected_type in enumerate(expected_types):
|
||||
uri = events[i]["explain_id"]
|
||||
triples = events[i]["triples"]
|
||||
assert has_type(triples, uri, expected_type), (
|
||||
f"Event {i} (uri={uri}) should have type {expected_type}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_derivation_chain_links_correctly(self):
|
||||
"""
|
||||
Each event's URI should link to the previous via wasDerivedFrom:
|
||||
grounding → question → (none)
|
||||
exploration → grounding
|
||||
focus → exploration
|
||||
synthesis → focus
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
all_triples = []
|
||||
for e in events:
|
||||
all_triples.extend(e["triples"])
|
||||
|
||||
# question has no parent
|
||||
assert derived_from(all_triples, uris[0]) is None
|
||||
|
||||
# grounding → question
|
||||
assert derived_from(all_triples, uris[1]) == uris[0]
|
||||
|
||||
# exploration → grounding
|
||||
assert derived_from(all_triples, uris[2]) == uris[1]
|
||||
|
||||
# focus → exploration
|
||||
assert derived_from(all_triples, uris[3]) == uris[2]
|
||||
|
||||
# synthesis → focus
|
||||
assert derived_from(all_triples, uris[4]) == uris[3]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_question_event_carries_query_text(self):
|
||||
"""The question event should contain the original query string."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
q_triples = events[0]["triples"]
|
||||
t = find_triple(q_triples, TG_QUERY, q_uri)
|
||||
assert t is not None
|
||||
assert t.o.value == "What is quantum computing?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grounding_carries_concepts(self):
|
||||
"""The grounding event should list extracted concepts."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
gnd_uri = events[1]["explain_id"]
|
||||
gnd_triples = events[1]["triples"]
|
||||
concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
|
||||
concept_values = {t.o.value for t in concepts}
|
||||
assert "quantum computing" in concept_values
|
||||
assert "physics" in concept_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exploration_has_edge_count(self):
|
||||
"""The exploration event should report how many edges were found."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
exp_triples = events[2]["triples"]
|
||||
t = find_triple(exp_triples, TG_EDGE_COUNT, exp_uri)
|
||||
assert t is not None
|
||||
# Should be non-zero (we provided 3 edges, label edges filtered)
|
||||
assert int(t.o.value) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_focus_has_selected_edges_with_reasoning(self):
|
||||
"""
|
||||
The focus event should carry selected edges as quoted triples
|
||||
with reasoning text.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
foc_uri = events[3]["explain_id"]
|
||||
foc_triples = events[3]["triples"]
|
||||
|
||||
# Should have selected edges
|
||||
selected = find_triples(foc_triples, TG_SELECTED_EDGE, foc_uri)
|
||||
assert len(selected) > 0, "Focus should have at least one selected edge"
|
||||
|
||||
# Each edge selection should have a quoted triple
|
||||
edge_t = find_triples(foc_triples, TG_EDGE)
|
||||
assert len(edge_t) > 0, "Focus should have tg:edge with quoted triples"
|
||||
for t in edge_t:
|
||||
assert t.o.triple is not None, "tg:edge object must be a quoted triple"
|
||||
|
||||
# Should have reasoning
|
||||
reasoning = find_triples(foc_triples, TG_REASONING)
|
||||
assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
|
||||
reasoning_texts = {t.o.value for t in reasoning}
|
||||
assert any(r for r in reasoning_texts), "Reasoning should not be empty"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_is_answer_type(self):
|
||||
"""The synthesis event should have tg:Answer type."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
syn_uri = events[4]["explain_id"]
|
||||
syn_triples = events[4]["triples"]
|
||||
assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
|
||||
assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_answer_text(self):
|
||||
"""query() should still return the synthesised answer."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
result = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parent_uri_links_question_to_parent(self):
|
||||
"""When parent_uri is provided, question should derive from it."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
parent = "urn:trustgraph:agent:iteration:xyz"
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
parent_uri=parent,
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
q_triples = events[0]["triples"]
|
||||
assert derived_from(q_triples, q_uri) == parent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_still_works(self):
|
||||
"""query() without explain_callback should return answer normally."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
"""All emitted triples should be in the urn:graph:retrieval graph."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
for event in events:
|
||||
for t in event["triples"]:
|
||||
assert t.g == "urn:graph:retrieval", (
|
||||
f"Triple {t.s.iri} {t.p.iri} should be in "
|
||||
f"urn:graph:retrieval, got {t.g}"
|
||||
)
|
||||
|
|
@ -15,7 +15,7 @@ class GraphRagClient(RequestResponse):
|
|||
user: User identifier
|
||||
collection: Collection identifier
|
||||
chunk_callback: Optional async callback(text, end_of_stream) for text chunks
|
||||
explain_callback: Optional async callback(explain_id, explain_graph) for explain notifications
|
||||
explain_callback: Optional async callback(explain_id, explain_graph, explain_triples) for explain notifications
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
|
|
@ -30,7 +30,7 @@ class GraphRagClient(RequestResponse):
|
|||
# Handle explain notifications
|
||||
if resp.message_type == 'explain':
|
||||
if explain_callback and resp.explain_id:
|
||||
await explain_callback(resp.explain_id, resp.explain_graph)
|
||||
await explain_callback(resp.explain_id, resp.explain_graph, resp.explain_triples)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle text chunks
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class DocumentRagClient(BaseClient):
|
|||
user: User identifier
|
||||
collection: Collection identifier
|
||||
chunk_callback: Optional callback(text, end_of_stream) for text chunks
|
||||
explain_callback: Optional callback(explain_id, explain_graph) for explain notifications
|
||||
explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
|
|
@ -55,7 +55,7 @@ class DocumentRagClient(BaseClient):
|
|||
# Handle explain notifications (response is None/empty, explain_id present)
|
||||
if x.explain_id and not x.response:
|
||||
if explain_callback:
|
||||
explain_callback(x.explain_id, x.explain_graph)
|
||||
explain_callback(x.explain_id, x.explain_graph, x.explain_triples)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle text chunks
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class GraphRagClient(BaseClient):
|
|||
user: User identifier
|
||||
collection: Collection identifier
|
||||
chunk_callback: Optional callback(text, end_of_stream) for text chunks
|
||||
explain_callback: Optional callback(explain_id, explain_graph) for explain notifications
|
||||
explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
|
|
@ -59,7 +59,7 @@ class GraphRagClient(BaseClient):
|
|||
# Handle explain notifications
|
||||
if x.message_type == 'explain':
|
||||
if explain_callback and x.explain_id:
|
||||
explain_callback(x.explain_id, x.explain_graph)
|
||||
explain_callback(x.explain_id, x.explain_graph, x.explain_triples)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle text chunks
|
||||
|
|
|
|||
|
|
@ -465,11 +465,18 @@ def exploration_triples(
|
|||
return triples
|
||||
|
||||
|
||||
def _quoted_triple(s: str, p: str, o: str) -> Term:
|
||||
"""Create a quoted triple term (RDF-star) from string values."""
|
||||
def _quoted_triple(s, p, o) -> Term:
|
||||
"""Create a quoted triple term (RDF-star).
|
||||
|
||||
Accepts either Term objects (preserving original types) or plain
|
||||
strings (treated as IRIs for backward compatibility).
|
||||
"""
|
||||
s_term = s if isinstance(s, Term) else _iri(s)
|
||||
p_term = p if isinstance(p, Term) else _iri(p)
|
||||
o_term = o if isinstance(o, Term) else _iri(o)
|
||||
return Term(
|
||||
type=TRIPLE,
|
||||
triple=Triple(s=_iri(s), p=_iri(p), o=_iri(o))
|
||||
triple=Triple(s=s_term, p=p_term, o=o_term)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -39,13 +39,14 @@ class KnowledgeQueryImpl:
|
|||
if respond:
|
||||
from ... schema import AgentResponse
|
||||
|
||||
async def explain_callback(explain_id, explain_graph):
|
||||
async def explain_callback(explain_id, explain_graph, explain_triples=None):
|
||||
self.context.last_sub_explain_uri = explain_id
|
||||
await respond(AgentResponse(
|
||||
chunk_type="explain",
|
||||
content="",
|
||||
explain_id=explain_id,
|
||||
explain_graph=explain_graph,
|
||||
explain_triples=explain_triples or [],
|
||||
))
|
||||
|
||||
if current_uri:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from collections import OrderedDict
|
|||
from datetime import datetime
|
||||
|
||||
from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE
|
||||
from ... knowledge import Uri, Literal
|
||||
|
||||
# Provenance imports
|
||||
from trustgraph.provenance import (
|
||||
|
|
@ -46,6 +47,26 @@ def term_to_string(term):
|
|||
return term.iri or term.value or str(term)
|
||||
|
||||
|
||||
def to_term(val):
|
||||
"""Convert a Uri, Literal, or string to a schema Term.
|
||||
|
||||
The triples client returns Uri/Literal (str subclasses) rather than
|
||||
Term objects. This converts them back so provenance quoted triples
|
||||
preserve the correct type.
|
||||
"""
|
||||
if isinstance(val, Term):
|
||||
return val
|
||||
if isinstance(val, Uri):
|
||||
return Term(type=IRI, iri=str(val))
|
||||
if isinstance(val, Literal):
|
||||
return Term(type=LITERAL, value=str(val))
|
||||
# Fallback: treat as IRI if it looks like one, otherwise literal
|
||||
s = str(val)
|
||||
if s.startswith(("http://", "https://", "urn:")):
|
||||
return Term(type=IRI, iri=s)
|
||||
return Term(type=LITERAL, value=s)
|
||||
|
||||
|
||||
def edge_id(s, p, o):
|
||||
"""Generate an 8-character hash ID for an edge (s, p, o)."""
|
||||
edge_str = f"{s}|{p}|{o}"
|
||||
|
|
@ -258,10 +279,18 @@ class Query:
|
|||
return all_triples
|
||||
|
||||
async def follow_edges_batch(self, entities, max_depth):
|
||||
"""Optimized iterative graph traversal with batching"""
|
||||
"""Optimized iterative graph traversal with batching.
|
||||
|
||||
Returns:
|
||||
tuple: (subgraph, term_map) where subgraph is a set of
|
||||
(str, str, str) tuples and term_map maps each string tuple
|
||||
to its original (Term, Term, Term) for type-preserving
|
||||
provenance.
|
||||
"""
|
||||
visited = set()
|
||||
current_level = set(entities)
|
||||
subgraph = set()
|
||||
term_map = {} # (str, str, str) -> (Term, Term, Term)
|
||||
|
||||
for depth in range(max_depth):
|
||||
if not current_level or len(subgraph) >= self.max_subgraph_size:
|
||||
|
|
@ -282,6 +311,7 @@ class Query:
|
|||
for triple in triples:
|
||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||
subgraph.add(triple_tuple)
|
||||
term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o))
|
||||
|
||||
# Collect entities for next level (only from s and o positions)
|
||||
if depth < max_depth - 1: # Don't collect for final depth
|
||||
|
|
@ -293,13 +323,13 @@ class Query:
|
|||
|
||||
# Stop if subgraph size limit reached
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return subgraph
|
||||
return subgraph, term_map
|
||||
|
||||
# Update for next iteration
|
||||
visited.update(current_level)
|
||||
current_level = next_level
|
||||
|
||||
return subgraph
|
||||
return subgraph, term_map
|
||||
|
||||
async def follow_edges(self, ent, subgraph, path_length):
|
||||
"""Legacy method - replaced by follow_edges_batch"""
|
||||
|
|
@ -311,7 +341,7 @@ class Query:
|
|||
return
|
||||
|
||||
# For backward compatibility, convert to new approach
|
||||
batch_result = await self.follow_edges_batch([ent], path_length)
|
||||
batch_result, _ = await self.follow_edges_batch([ent], path_length)
|
||||
subgraph.update(batch_result)
|
||||
|
||||
async def get_subgraph(self, query):
|
||||
|
|
@ -319,9 +349,10 @@ class Query:
|
|||
Get subgraph by extracting concepts, finding entities, and traversing.
|
||||
|
||||
Returns:
|
||||
tuple: (subgraph, entities, concepts) where subgraph is a list of
|
||||
(s, p, o) tuples, entities is the seed entity list, and concepts
|
||||
is the extracted concept list.
|
||||
tuple: (subgraph, term_map, entities, concepts) where subgraph is
|
||||
a list of (s, p, o) string tuples, term_map maps each string
|
||||
tuple to its original (Term, Term, Term), entities is the seed
|
||||
entity list, and concepts is the extracted concept list.
|
||||
"""
|
||||
|
||||
entities, concepts = await self.get_entities(query)
|
||||
|
|
@ -330,9 +361,9 @@ class Query:
|
|||
logger.debug("Getting subgraph...")
|
||||
|
||||
# Use optimized batch traversal instead of sequential processing
|
||||
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
|
||||
subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length)
|
||||
|
||||
return list(subgraph), entities, concepts
|
||||
return list(subgraph), term_map, entities, concepts
|
||||
|
||||
async def resolve_labels_batch(self, entities):
|
||||
"""Resolve labels for multiple entities in parallel"""
|
||||
|
|
@ -353,7 +384,7 @@ class Query:
|
|||
- entities: list of seed entity URI strings
|
||||
- concepts: list of concept strings extracted from query
|
||||
"""
|
||||
subgraph, entities, concepts = await self.get_subgraph(query)
|
||||
subgraph, term_map, entities, concepts = await self.get_subgraph(query)
|
||||
|
||||
# Filter out label triples
|
||||
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
|
||||
|
|
@ -377,7 +408,7 @@ class Query:
|
|||
|
||||
# Apply labels to subgraph and build URI mapping
|
||||
labeled_edges = []
|
||||
uri_map = {} # Maps edge_id of labeled edge -> original URI triple
|
||||
uri_map = {} # Maps edge_id of labeled edge -> original Term triple
|
||||
|
||||
for s, p, o in filtered_subgraph:
|
||||
labeled_triple = (
|
||||
|
|
@ -387,9 +418,9 @@ class Query:
|
|||
)
|
||||
labeled_edges.append(labeled_triple)
|
||||
|
||||
# Map from labeled edge ID to original URIs
|
||||
# Map from labeled edge ID to original Terms (preserving types)
|
||||
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
|
||||
uri_map[labeled_eid] = (s, p, o)
|
||||
uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o))
|
||||
|
||||
labeled_edges = labeled_edges[0:self.max_subgraph_size]
|
||||
|
||||
|
|
@ -419,12 +450,14 @@ class Query:
|
|||
# Step 1: Find subgraphs containing these edges via tg:contains
|
||||
subgraph_tasks = []
|
||||
for s, p, o in edge_uris:
|
||||
# s, p, o may be Term objects (preserving types) or strings
|
||||
s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s)
|
||||
p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p)
|
||||
o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o)
|
||||
quoted = Term(
|
||||
type=TRIPLE,
|
||||
triple=SchemaTriple(
|
||||
s=Term(type=IRI, iri=s),
|
||||
p=Term(type=IRI, iri=p),
|
||||
o=Term(type=IRI, iri=o),
|
||||
s=s_term, p=p_term, o=o_term,
|
||||
)
|
||||
)
|
||||
subgraph_tasks.append(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue