GraphRAG Query-Time Explainability (#677)

Implements a full explainability pipeline for GraphRAG queries, enabling
traceability from answers back to source documents.

Renamed throughout for clarity:
- provenance_callback → explain_callback
- provenance_id → explain_id
- provenance_collection → explain_collection
- message_type "provenance" → "explain"
- Queue name "provenance" → "explainability"

GraphRAG queries now emit explainability events as they execute:
1. Session - query text and timestamp
2. Retrieval - edges retrieved from subgraph
3. Selection - selected edges with LLM reasoning (JSONL or JSON array
   of {id, reasoning} objects)
4. Answer - reference to synthesized response

Events stream via explain_callback during query(), so clients can
render progress in real time.
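
As a minimal sketch of consuming these events (assuming `rag` is an
already-initialized GraphRag instance running inside an async context;
the callback signature follows the query() docstring in the diff below):

```python
# Sketch only: `rag` is assumed to be an initialized GraphRag
# instance, and this runs inside an async context.
async def on_explain(triples, explain_id):
    # Fires once per stage: session, retrieval, selection, answer
    print(f"explain event {explain_id}: {len(triples)} triples")

answer = await rag.query(
    "What does the contract say about termination?",
    explain_callback=on_explain,
)
```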

Answer storage:
- Answers stored in the librarian service (not inline in the graph,
  where they would be too large)
- Document ID as URN: urn:trustgraph:answer:{session_id}
- Graph stores tg:document reference (IRI) to librarian document
- Added librarian producer/consumer to graph-rag service
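
A sketch of the storage flow described above (session_id comes from
uuid.uuid4() in query(), and save_answer_callback is the hook rag.py
wires to the librarian producer; both appear in the diffs below):

```python
# Sketch; session_id, answer_text, ans_uri and sel_uri are produced
# inside query() as shown in the graph_rag.py diff below.
answer_doc_id = f"urn:trustgraph:answer:{session_id}"
await save_answer_callback(answer_doc_id, answer_text)

# The graph then carries only an IRI reference, not the content:
ans_triples = answer_triples(
    ans_uri, sel_uri,
    answer_text="",             # empty when stored out-of-line
    document_id=answer_doc_id,  # emitted as a tg:document reference
)
```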

URI preservation (graph_rag.py):
- get_labelgraph() now returns (labeled_edges, uri_map)
- uri_map maps edge_id(label_s, label_p, label_o) →
  (uri_s, uri_p, uri_o)
- Explainability data stores original URIs, not labels
- Enables tracing edges back to reifying statements via tg:reifies
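
A sketch of how a caller uses the new return shape (names follow the
graph_rag.py diff below):

```python
# Labels feed the LLM prompt; uri_map recovers the original URIs
# for explainability triples.
labeled_edges, uri_map = await q.get_labelgraph(query)

for label_s, label_p, label_o in labeled_edges:
    eid = edge_id(label_s, label_p, label_o)  # 8-char sha256 prefix
    uri_s, uri_p, uri_o = uri_map[eid]        # original URI triple
    # The URI form is what later matches a reifying statement
    # via tg:reifies.
```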

Quoted triple querying (Cassandra query service):
- Added serialize_triple() to query service (matches storage format)
- get_term_value() now handles TRIPLE type terms
- Enables querying by quoted triple in object position:
  ?stmt tg:reifies <<s p o>>
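
As a rough illustration of the matching idea only (the delimiter
format here is hypothetical; the real serialize_triple() must
reproduce the Cassandra storage format exactly, and stored_triples
stands in for the service's lookup):

```python
# Hypothetical sketch: the delimiter choice and stored_triples are
# illustrative, not the real storage format or API.
def serialize_triple(s, p, o):
    return f"<< {s} {p} {o} >>"

# The pattern  ?stmt tg:reifies <<s p o>>  reduces to an exact-value
# match on the serialized quoted triple in object position:
quoted = serialize_triple(uri_s, uri_p, uri_o)
matches = [t for t in stored_triples if t.o == quoted]
```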

CLI display (invoke_graph_rag):
- Displays real-time explainability events during query
- Resolves rdfs:label for edge components (s, p, o)
- Traces source chain via prov:wasDerivedFrom to root document
- Output: "Source: Chunk 1 → Page 2 → Document Title"
- Label caching to avoid repeated queries
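
A hedged sketch of that walk (get_label and get_derived_from are
hypothetical stand-ins for the CLI's rdfs:label and
prov:wasDerivedFrom queries):

```python
# Hypothetical helper names; caching mirrors the behaviour above.
label_cache = {}

def label_of(uri):
    if uri not in label_cache:
        label_cache[uri] = get_label(uri) or uri  # rdfs:label lookup
    return label_cache[uri]

def trace_source(uri):
    chain = []
    while uri is not None:
        chain.append(label_of(uri))
        uri = get_derived_from(uri)  # follow prov:wasDerivedFrom
    return "Source: " + " → ".join(chain)
```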

GraphRagResponse:
- explain_id: str | None
- explain_collection: str | None
- message_type: str ("chunk" or "explain")
- end_of_session: bool
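
A client consuming the response stream might dispatch on message_type
like this (receive_responses() and remember() are hypothetical):

```python
# Sketch of client-side dispatch; receive_responses() is a
# hypothetical async iterator over GraphRagResponse messages
# for one request ID.
async for resp in receive_responses():
    if resp.message_type == "explain":
        # A reference to explainability triples, not answer text
        remember(resp.explain_id, resp.explain_collection)
    elif resp.message_type == "chunk":
        if resp.response:
            print(resp.response, end="")
        if resp.end_of_session:
            break  # session closed by the service
```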

trustgraph-base/trustgraph/provenance/:
- namespaces.py - Added TG_DOCUMENT predicate
- triples.py - answer_triples() supports document_id reference
- uris.py - Added edge_selection_uri()
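
How these helpers compose during a query, following the calls in the
graph_rag.py diff below (only the timestamp value is illustrative):

```python
# Follows the calls in the graph_rag.py diff below; session_id,
# ret_uri, query, kg and explain_callback are in scope in query().
session_uri = query_session_uri(session_id)
timestamp = datetime.utcnow().isoformat() + "Z"

session_triples = query_session_triples(session_uri, query, timestamp)
await explain_callback(session_triples, session_uri)

ret_triples = retrieval_triples(ret_uri, session_uri, len(kg))
await explain_callback(ret_triples, ret_uri)
```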

trustgraph-base/trustgraph/schema/services/retrieval.py:
- GraphRagResponse with explain_id, explain_collection, end_of_session

trustgraph-flow/trustgraph/retrieval/graph_rag/:
- graph_rag.py - URI preservation, streaming answer accumulation
- rag.py - Librarian integration, real-time explain emission

trustgraph-flow/trustgraph/query/triples/cassandra/service.py:
- Quoted triple serialization for query matching

trustgraph-cli/trustgraph/cli/invoke_graph_rag.py:
- Full explainability display with label resolution and source tracing

trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py

@@ -1,11 +1,27 @@
import asyncio
import hashlib
import json
import logging
import time
import uuid
from collections import OrderedDict
from datetime import datetime
from ... schema import IRI, LITERAL
# Provenance imports
from trustgraph.provenance import (
query_session_uri,
retrieval_uri as make_retrieval_uri,
selection_uri as make_selection_uri,
answer_uri as make_answer_uri,
query_session_triples,
retrieval_triples,
selection_triples,
answer_triples,
)
# Module logger
logger = logging.getLogger(__name__)
@@ -23,6 +39,12 @@ def term_to_string(term):
# Fallback
return term.iri or term.value or str(term)
def edge_id(s, p, o):
"""Generate an 8-character hash ID for an edge (s, p, o)."""
edge_str = f"{s}|{p}|{o}"
return hashlib.sha256(edge_str.encode()).hexdigest()[:8]
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
@@ -258,7 +280,14 @@ class Query:
return await asyncio.gather(*tasks, return_exceptions=True)
async def get_labelgraph(self, query):
"""
Get subgraph with labels resolved for display.
Returns:
tuple: (labeled_edges, uri_map) where:
- labeled_edges: list of (label_s, label_p, label_o) tuples
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
"""
subgraph = await self.get_subgraph(query)
# Filter out label triples
@@ -281,27 +310,33 @@ class Query:
else:
label_map[entity] = entity # Fallback to entity itself
# Apply labels to subgraph
sg2 = []
# Apply labels to subgraph and build URI mapping
labeled_edges = []
uri_map = {} # Maps edge_id of labeled edge -> original URI triple
for s, p, o in filtered_subgraph:
labeled_triple = (
label_map.get(s, s),
label_map.get(p, p),
label_map.get(o, o)
)
sg2.append(labeled_triple)
labeled_edges.append(labeled_triple)
sg2 = sg2[0:self.max_subgraph_size]
# Map from labeled edge ID to original URIs
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
uri_map[labeled_eid] = (s, p, o)
labeled_edges = labeled_edges[0:self.max_subgraph_size]
if self.verbose:
logger.debug("Subgraph:")
for edge in sg2:
for edge in labeled_edges:
logger.debug(f" {str(edge)}")
if self.verbose:
logger.debug("Done.")
return sg2
return labeled_edges, uri_map
class GraphRag:
"""
@@ -335,11 +370,44 @@ class GraphRag:
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, streaming = False, chunk_callback = None,
explain_callback = None, save_answer_callback = None,
):
"""
Execute a GraphRAG query with real-time explainability tracking.
Args:
query: The query string
user: User identifier
collection: Collection identifier
entity_limit: Max entities to retrieve
triple_limit: Max triples per entity
max_subgraph_size: Max edges in subgraph
max_path_length: Max hops from seed entities
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for real-time explainability
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
Returns:
str: The synthesized answer text
"""
if self.verbose:
logger.debug("Constructing prompt...")
# Generate explainability URIs upfront
session_id = str(uuid.uuid4())
session_uri = query_session_uri(session_id)
ret_uri = make_retrieval_uri(session_id)
sel_uri = make_selection_uri(session_id)
ans_uri = make_answer_uri(session_id)
timestamp = datetime.utcnow().isoformat() + "Z"
# Emit session explainability immediately
if explain_callback:
session_triples = query_session_triples(session_uri, query, timestamp)
await explain_callback(session_triples, session_uri)
q = Query(
rag = self, user = user, collection = collection,
verbose = self.verbose, entity_limit = entity_limit,
@@ -348,24 +416,171 @@
max_path_length = max_path_length,
)
kg = await q.get_labelgraph(query)
kg, uri_map = await q.get_labelgraph(query)
# Emit retrieval explain after graph retrieval completes
if explain_callback:
ret_triples = retrieval_triples(ret_uri, session_uri, len(kg))
await explain_callback(ret_triples, ret_uri)
if self.verbose:
logger.debug("Invoking LLM...")
logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}")
if streaming and chunk_callback:
resp = await self.prompt_client.kg_prompt(
query, kg,
streaming=True,
chunk_callback=chunk_callback
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
edge_map = {}
edges_with_ids = []
for s, p, o in kg:
eid = edge_id(s, p, o)
edge_map[eid] = (s, p, o)
edges_with_ids.append({
"id": eid,
"s": s,
"p": p,
"o": o
})
if self.verbose:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1: Edge Selection - LLM selects relevant edges with reasoning
selection_response = await self.prompt_client.prompt(
"kg-edge-selection",
variables={
"query": query,
"knowledge": edges_with_ids
}
)
if self.verbose:
logger.debug(f"Edge selection response: {selection_response}")
# Parse response to get selected edge IDs and reasoning
# Response can be a string (JSONL) or a list (JSON array)
selected_ids = set()
selected_edges_with_reasoning = [] # For explain
if isinstance(selection_response, list):
# JSON array response
for obj in selection_response:
if isinstance(obj, dict) and "id" in obj:
selected_ids.add(obj["id"])
# Capture original URI edge (not labels) and reasoning for explain
eid = obj["id"]
if eid in uri_map:
# Use original URIs for provenance tracing
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": obj.get("reasoning", ""),
})
elif isinstance(selection_response, str):
# JSONL string response
for line in selection_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
if "id" in obj:
selected_ids.add(obj["id"])
# Capture original URI edge (not labels) and reasoning for explain
eid = obj["id"]
if eid in uri_map:
# Use original URIs for provenance tracing
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": obj.get("reasoning", ""),
})
except json.JSONDecodeError:
logger.warning(f"Failed to parse edge selection line: {line}")
continue
if self.verbose:
logger.debug(f"Selected {len(selected_ids)} edges: {selected_ids}")
# Filter to selected edges
selected_edges = []
for eid in selected_ids:
if eid in edge_map:
selected_edges.append(edge_map[eid])
if self.verbose:
logger.debug(f"Filtered to {len(selected_edges)} edges")
# Emit selection explain after edge selection completes
if explain_callback:
sel_triples = selection_triples(
sel_uri, ret_uri, selected_edges_with_reasoning, session_id
)
await explain_callback(sel_triples, sel_uri)
# Step 2: Synthesis - LLM generates answer from selected edges only
selected_edge_dicts = [
{"s": s, "p": p, "o": o}
for s, p, o in selected_edges
]
if streaming and chunk_callback:
# Accumulate chunks for answer storage while forwarding to callback
accumulated_chunks = []
async def accumulating_callback(chunk, end_of_stream):
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
await self.prompt_client.prompt(
"kg-synthesis",
variables={
"query": query,
"knowledge": selected_edge_dicts
},
streaming=True,
chunk_callback=accumulating_callback
)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
resp = await self.prompt_client.kg_prompt(query, kg)
resp = await self.prompt_client.prompt(
"kg-synthesis",
variables={
"query": query,
"knowledge": selected_edge_dicts
}
)
if self.verbose:
logger.debug("Query processing complete")
# Emit answer explain after synthesis completes
if explain_callback:
answer_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian if callback provided
if save_answer_callback and answer_text:
# Generate document ID as URN matching query-time provenance format
answer_doc_id = f"urn:trustgraph:answer:{session_id}"
try:
await save_answer_callback(answer_doc_id, answer_text)
if self.verbose:
logger.debug(f"Saved answer to librarian: {answer_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
answer_doc_id = None # Fall back to inline content
# Generate triples with document reference or inline content
ans_triples = answer_triples(
ans_uri, sel_uri,
answer_text="" if answer_doc_id else answer_text,
document_id=answer_doc_id,
)
await explain_callback(ans_triples, ans_uri)
if self.verbose:
logger.debug(f"Emitted explain for session {session_id}")
return resp

trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py

@@ -4,18 +4,28 @@ Simple RAG service, performs query using graph RAG and an LLM.
Input is query, output is response.
"""
import asyncio
import base64
import logging
import uuid
from ... schema import GraphRagQuery, GraphRagResponse, Error
from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "graph-rag"
default_concurrency = 1
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor):
@@ -28,6 +38,7 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
explainability_collection = params.get("explainability_collection", "explainability")
super(Processor, self).__init__(
**params | {
@@ -37,6 +48,7 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"explainability_collection": explainability_collection,
}
)
@@ -44,6 +56,7 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.explainability_collection = explainability_collection
# CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access
@@ -93,10 +106,163 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
ProducerSpec(
name = "explainability",
schema = Triples,
)
)
# Librarian client for storing answer content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_librarian_requests = {}
logger.info("Graph RAG service initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
else:
logger.warning(f"Received unexpected librarian response: {request_id}")
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
user: User ID
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or "GraphRAG Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_librarian_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_request(self, msg, consumer, flow):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.info(f"Handling input {id}...")
# Track explainability refs for end_of_session signaling
explainability_refs_emitted = []
# Real-time explainability callback - emits triples and IDs as they're generated
async def send_explainability(triples, explain_id):
# Send triples to explainability queue
await flow("explainability").send(Triples(
metadata=Metadata(
id=explain_id,
metadata=[],
user=v.user,
collection=self.explainability_collection,
),
triples=triples,
))
# Send explain ID and collection to response queue
await flow("response").send(
GraphRagResponse(
message_type="explain",
explain_id=explain_id,
explain_collection=self.explainability_collection,
),
properties={"id": id}
)
explainability_refs_emitted.append(explain_id)
# CRITICAL SECURITY: Create new GraphRag instance per request
# This ensures proper isolation between users and collections
# Flow clients are request-scoped and must not be shared
@@ -108,13 +274,6 @@ class Processor(FlowProcessor):
verbose=True,
)
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.info(f"Handling input {id}...")
if v.entity_limit:
entity_limit = v.entity_limit
else:
@@ -135,6 +294,15 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
doc_id=doc_id,
user=v.user,
content=answer_text,
title=f"GraphRAG Answer: {v.query[:50]}...",
)
# Check if streaming is requested
if v.streaming:
# Define async callback for streaming chunks
@@ -142,6 +310,7 @@ class Processor(FlowProcessor):
async def send_chunk(chunk, end_of_stream):
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response=chunk,
end_of_stream=end_of_stream,
error=None
@@ -149,34 +318,50 @@ class Processor(FlowProcessor):
properties={"id": id}
)
# Query with streaming enabled
# All chunks (including final one with end_of_stream=True) are sent via callback
await rag.query(
# Query with streaming and real-time explain
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
streaming = True,
chunk_callback = send_chunk,
explain_callback = send_explainability,
save_answer_callback = save_answer,
)
else:
# Non-streaming path (existing behavior)
# Non-streaming path with real-time explain
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
explain_callback = send_explainability,
save_answer_callback = save_answer,
)
# Send chunk with response
await flow("response").send(
GraphRagResponse(
response = response,
end_of_stream = True,
error = None
message_type="chunk",
response=response,
end_of_stream=True,
error=None,
),
properties = {"id": id}
properties={"id": id}
)
# Send final message to close session
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response="",
end_of_session=True,
),
properties={"id": id}
)
logger.info("Request processing complete")
except Exception as e:
@@ -185,22 +370,18 @@ class Processor(FlowProcessor):
logger.debug("Sending error response...")
# Send error response with end_of_stream flag if streaming was requested
error_response = GraphRagResponse(
response = None,
error = Error(
type = "graph-rag-error",
message = str(e),
),
)
# If streaming was requested, indicate stream end
if v.streaming:
error_response.end_of_stream = True
# Send error response and close session
await flow("response").send(
error_response,
properties = {"id": id}
GraphRagResponse(
message_type="chunk",
error=Error(
type="graph-rag-error",
message=str(e),
),
end_of_stream=True,
end_of_session=True,
),
properties={"id": id}
)
@staticmethod
@@ -243,6 +424,12 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)'
)
parser.add_argument(
'--explainability-collection',
default='explainability',
help=f'Collection for storing explainability triples (default: explainability)'
)
def run():
Processor.launch(default_ident, __doc__)