GraphRAG Query-Time Explainability (#677)
Implements a full explainability pipeline for GraphRAG queries, enabling
traceability from answers back to source documents.
Renamed throughout for clarity:
- provenance_callback → explain_callback
- provenance_id → explain_id
- provenance_collection → explain_collection
- message_type "provenance" → "explain"
- Queue name "provenance" → "explainability"
GraphRAG queries now emit explainability events as they execute:
1. Session - query text and timestamp
2. Retrieval - edges retrieved from subgraph
3. Selection - selected edges with LLM reasoning (JSONL records with id + reasoning)
4. Answer - reference to synthesized response
Events stream via explain_callback during query(), enabling real-time UX.
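A minimal sketch of a consumer hook (the triples/explain_id signature matches
the service's explain_callback; the selection record keys follow the JSONL
description above, and the payload values are made up for illustration):

    import json

    async def explain_callback(triples, explain_id):
        # Invoked once per event: session, retrieval, selection, answer
        print(f"explain event {explain_id}: {len(triples)} triples")

    # Selection events carry JSONL, one record per selected edge:
    record = json.loads('{"id": "edge-17", "reasoning": "connects the query entities"}')
    print(record["id"], "->", record["reasoning"])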
Answer storage:
- Answers stored in the librarian service (too large to store inline in the graph)
- Document ID as URN: urn:trustgraph:answer:{session_id}
- Graph stores tg:document reference (IRI) to librarian document
- Added librarian producer/consumer to graph-rag service
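A sketch of the reference pattern, assuming only the URN scheme above (the
real triples are built by answer_triples() in trustgraph-base):

    def answer_document_urn(session_id: str) -> str:
        # Document ID under which the librarian stores the answer body
        return f"urn:trustgraph:answer:{session_id}"

    doc_id = answer_document_urn("3f2c9a")  # made-up session ID
    # The graph itself holds only an IRI reference, conceptually:
    #   <answer> tg:document <urn:trustgraph:answer:3f2c9a>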
URI preservation:
- get_labelgraph() now returns (labeled_edges, uri_map)
- uri_map maps edge_id(label_s, label_p, label_o) → (uri_s, uri_p, uri_o)
- Explainability data stores original URIs, not labels
- Enables tracing edges back to reifying statements via tg:reifies
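Illustrative shape of the new return value (the real map is keyed by
edge_id(label_s, label_p, label_o); a plain tuple key stands in for it here,
and the URIs are placeholders):

    labeled_edges = [("Alice", "works for", "Acme")]
    uri_map = {
        ("Alice", "works for", "Acme"): (
            "http://example.org/alice",     # uri_s
            "http://example.org/worksFor",  # uri_p
            "http://example.org/acme",      # uri_o
        ),
    }

    # Selected edges resolve back to their original URIs for tracing:
    s_uri, p_uri, o_uri = uri_map[labeled_edges[0]]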
Query service:
- Added serialize_triple() (matches the storage serialization format)
- get_term_value() now handles TRIPLE type terms
- Enables querying by quoted triple in object position:
?stmt tg:reifies <<s p o>>
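A rough sketch of the idea (the <<s p o>> form comes from the pattern above;
the actual encoding is whatever serialize_triple() and the Cassandra store
agree on, which this does not claim to match):

    def serialize_triple(s: str, p: str, o: str) -> str:
        # Render a quoted triple as a single term so it can appear
        # in the object position of a tg:reifies lookup
        return f"<<{s} {p} {o}>>"

    quoted = serialize_triple("ex:alice", "ex:worksFor", "ex:acme")
    # Conceptual query: ?stmt tg:reifies <<ex:alice ex:worksFor ex:acme>>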
CLI display:
- Displays real-time explainability events during query
- Resolves rdfs:label for edge components (s, p, o)
- Traces source chain via prov:wasDerivedFrom to root document
- Output: "Source: Chunk 1 → Page 2 → Document Title"
- Label caching to avoid repeated queries
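Sketches of the two display mechanics, with lookup as a stand-in for whatever
async call actually fetches rdfs:label:

    label_cache: dict[str, str] = {}

    async def resolve_label(uri: str, lookup) -> str:
        # Memoize rdfs:label lookups so repeated s/p/o URIs
        # do not trigger repeated graph queries
        if uri not in label_cache:
            label_cache[uri] = await lookup(uri)
        return label_cache[uri]

    # Source chain assembled by walking prov:wasDerivedFrom to the root:
    chain = ["Chunk 1", "Page 2", "Document Title"]
    print("Source: " + " → ".join(chain))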
GraphRagResponse:
- explain_id: str | None
- explain_collection: str | None
- message_type: str ("chunk" or "explain")
- end_of_session: bool
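A client-side dispatch sketch over these fields (attribute-style access and
the handler body are illustrative, not the CLI's actual code):

    def handle_response(msg) -> bool:
        # Returns True once the server signals end of session
        if msg.message_type == "explain":
            # Reference to triples stored under (explain_collection, explain_id)
            print("explain ref:", msg.explain_collection, msg.explain_id)
        elif msg.message_type == "chunk" and msg.response:
            print(msg.response, end="")
        return bool(msg.end_of_session)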
trustgraph-base/trustgraph/provenance/:
- namespaces.py - Added TG_DOCUMENT predicate
- triples.py - answer_triples() supports document_id reference
- uris.py - Added edge_selection_uri()
trustgraph-base/trustgraph/schema/services/retrieval.py:
- GraphRagResponse with explain_id, explain_collection, end_of_session
trustgraph-flow/trustgraph/retrieval/graph_rag/:
- graph_rag.py - URI preservation, streaming answer accumulation
- rag.py - Librarian integration, real-time explain emission
trustgraph-flow/trustgraph/query/triples/cassandra/service.py:
- Quoted triple serialization for query matching
trustgraph-cli/trustgraph/cli/invoke_graph_rag.py:
- Full explainability display with label resolution and source tracing
Parent: d2d71f859d · Commit: 7a6197d8c3 · 24 changed files with 2001 additions and 323 deletions
@@ -4,18 +4,28 @@ Simple RAG service, performs query using graph RAG an LLM.
 Input is query, output is response.
 """
 
+import asyncio
+import base64
 import logging
+import uuid
 
 from ... schema import GraphRagQuery, GraphRagResponse, Error
 from ... schema import Triples, Metadata
+from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
+from ... schema import librarian_request_queue, librarian_response_queue
 from . graph_rag import GraphRag
 from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
 from ... base import PromptClientSpec, EmbeddingsClientSpec
 from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
+from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
 
 # Module logger
 logger = logging.getLogger(__name__)
 
 default_ident = "graph-rag"
 default_concurrency = 1
+default_librarian_request_queue = librarian_request_queue
+default_librarian_response_queue = librarian_response_queue
 
 class Processor(FlowProcessor):
@@ -28,6 +38,7 @@ class Processor(FlowProcessor):
         triple_limit = params.get("triple_limit", 30)
         max_subgraph_size = params.get("max_subgraph_size", 150)
         max_path_length = params.get("max_path_length", 2)
+        explainability_collection = params.get("explainability_collection", "explainability")
 
         super(Processor, self).__init__(
             **params | {
@@ -37,6 +48,7 @@ class Processor(FlowProcessor):
                 "triple_limit": triple_limit,
                 "max_subgraph_size": max_subgraph_size,
                 "max_path_length": max_path_length,
+                "explainability_collection": explainability_collection,
             }
         )
 
@@ -44,6 +56,7 @@ class Processor(FlowProcessor):
         self.default_triple_limit = triple_limit
         self.default_max_subgraph_size = max_subgraph_size
         self.default_max_path_length = max_path_length
+        self.explainability_collection = explainability_collection
 
         # CRITICAL SECURITY: NEVER share data between users or collections
         # Each user/collection combination MUST have isolated data access
@@ -93,10 +106,163 @@ class Processor(FlowProcessor):
             )
         )
 
+        self.register_specification(
+            ProducerSpec(
+                name = "explainability",
+                schema = Triples,
+            )
+        )
+
+        # Librarian client for storing answer content
+        librarian_request_q = params.get(
+            "librarian_request_queue", default_librarian_request_queue
+        )
+        librarian_response_q = params.get(
+            "librarian_response_queue", default_librarian_response_queue
+        )
+
+        librarian_request_metrics = ProducerMetrics(
+            processor=id, flow=None, name="librarian-request"
+        )
+
+        self.librarian_request_producer = Producer(
+            backend=self.pubsub,
+            topic=librarian_request_q,
+            schema=LibrarianRequest,
+            metrics=librarian_request_metrics,
+        )
+
+        librarian_response_metrics = ConsumerMetrics(
+            processor=id, flow=None, name="librarian-response"
+        )
+
+        self.librarian_response_consumer = Consumer(
+            taskgroup=self.taskgroup,
+            backend=self.pubsub,
+            flow=None,
+            topic=librarian_response_q,
+            subscriber=f"{id}-librarian",
+            schema=LibrarianResponse,
+            handler=self.on_librarian_response,
+            metrics=librarian_response_metrics,
+        )
+
+        # Pending librarian requests: request_id -> asyncio.Future
+        self.pending_librarian_requests = {}
+
+        logger.info("Graph RAG service initialized")
+
+    async def start(self):
+        await super(Processor, self).start()
+        await self.librarian_request_producer.start()
+        await self.librarian_response_consumer.start()
+
+    async def on_librarian_response(self, msg, consumer, flow):
+        """Handle responses from the librarian service."""
+        response = msg.value()
+        request_id = msg.properties().get("id")
+
+        if request_id and request_id in self.pending_librarian_requests:
+            future = self.pending_librarian_requests.pop(request_id)
+            future.set_result(response)
+        else:
+            logger.warning(f"Received unexpected librarian response: {request_id}")
+
+    async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
+        """
+        Save answer content to the librarian.
+
+        Args:
+            doc_id: ID for the answer document
+            user: User ID
+            content: Answer text content
+            title: Optional title
+            timeout: Request timeout in seconds
+
+        Returns:
+            The document ID on success
+        """
+        request_id = str(uuid.uuid4())
+
+        doc_metadata = DocumentMetadata(
+            id=doc_id,
+            user=user,
+            kind="text/plain",
+            title=title or "GraphRAG Answer",
+            document_type="answer",
+        )
+
+        request = LibrarianRequest(
+            operation="add-document",
+            document_id=doc_id,
+            document_metadata=doc_metadata,
+            content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
+            user=user,
+        )
+
+        # Create future for response
+        future = asyncio.get_event_loop().create_future()
+        self.pending_librarian_requests[request_id] = future
+
+        try:
+            # Send request
+            await self.librarian_request_producer.send(
+                request, properties={"id": request_id}
+            )
+
+            # Wait for response
+            response = await asyncio.wait_for(future, timeout=timeout)
+
+            if response.error:
+                raise RuntimeError(
+                    f"Librarian error saving answer: {response.error.type}: {response.error.message}"
+                )
+
+            return doc_id
+
+        except asyncio.TimeoutError:
+            self.pending_librarian_requests.pop(request_id, None)
+            raise RuntimeError(f"Timeout saving answer document {doc_id}")
+
     async def on_request(self, msg, consumer, flow):
 
         try:
 
+            v = msg.value()
+
+            # Sender-produced ID
+            id = msg.properties()["id"]
+
+            logger.info(f"Handling input {id}...")
+
+            # Track explainability refs for end_of_session signaling
+            explainability_refs_emitted = []
+
+            # Real-time explainability callback - emits triples and IDs as they're generated
+            async def send_explainability(triples, explain_id):
+                # Send triples to explainability queue
+                await flow("explainability").send(Triples(
+                    metadata=Metadata(
+                        id=explain_id,
+                        metadata=[],
+                        user=v.user,
+                        collection=self.explainability_collection,
+                    ),
+                    triples=triples,
+                ))
+
+                # Send explain ID and collection to response queue
+                await flow("response").send(
+                    GraphRagResponse(
+                        message_type="explain",
+                        explain_id=explain_id,
+                        explain_collection=self.explainability_collection,
+                    ),
+                    properties={"id": id}
+                )
+
+                explainability_refs_emitted.append(explain_id)
+
             # CRITICAL SECURITY: Create new GraphRag instance per request
             # This ensures proper isolation between users and collections
             # Flow clients are request-scoped and must not be shared
@@ -108,13 +274,6 @@ class Processor(FlowProcessor):
                 verbose=True,
             )
 
-            v = msg.value()
-
-            # Sender-produced ID
-            id = msg.properties()["id"]
-
-            logger.info(f"Handling input {id}...")
-
             if v.entity_limit:
                 entity_limit = v.entity_limit
             else:
@@ -135,6 +294,15 @@ class Processor(FlowProcessor):
             else:
                 max_path_length = self.default_max_path_length
 
+            # Callback to save answer content to librarian
+            async def save_answer(doc_id, answer_text):
+                await self.save_answer_content(
+                    doc_id=doc_id,
+                    user=v.user,
+                    content=answer_text,
+                    title=f"GraphRAG Answer: {v.query[:50]}...",
+                )
+
             # Check if streaming is requested
             if v.streaming:
                 # Define async callback for streaming chunks
@@ -142,6 +310,7 @@ class Processor(FlowProcessor):
                 async def send_chunk(chunk, end_of_stream):
                     await flow("response").send(
                         GraphRagResponse(
+                            message_type="chunk",
                             response=chunk,
                             end_of_stream=end_of_stream,
                             error=None
@@ -149,34 +318,50 @@ class Processor(FlowProcessor):
                         properties={"id": id}
                     )
 
-                # Query with streaming enabled
-                # All chunks (including final one with end_of_stream=True) are sent via callback
-                await rag.query(
+                # Query with streaming and real-time explain
+                response = await rag.query(
                     query = v.query, user = v.user, collection = v.collection,
                     entity_limit = entity_limit, triple_limit = triple_limit,
                     max_subgraph_size = max_subgraph_size,
                     max_path_length = max_path_length,
                     streaming = True,
                     chunk_callback = send_chunk,
+                    explain_callback = send_explainability,
+                    save_answer_callback = save_answer,
                 )
 
             else:
-                # Non-streaming path (existing behavior)
+                # Non-streaming path with real-time explain
                 response = await rag.query(
                     query = v.query, user = v.user, collection = v.collection,
                     entity_limit = entity_limit, triple_limit = triple_limit,
                     max_subgraph_size = max_subgraph_size,
                     max_path_length = max_path_length,
+                    explain_callback = send_explainability,
+                    save_answer_callback = save_answer,
                 )
 
                 # Send chunk with response
                 await flow("response").send(
                     GraphRagResponse(
-                        response = response,
-                        end_of_stream = True,
-                        error = None
+                        message_type="chunk",
+                        response=response,
+                        end_of_stream=True,
+                        error=None,
                     ),
-                    properties = {"id": id}
+                    properties={"id": id}
                 )
 
+            # Send final message to close session
+            await flow("response").send(
+                GraphRagResponse(
+                    message_type="chunk",
+                    response="",
+                    end_of_session=True,
+                ),
+                properties={"id": id}
+            )
+
             logger.info("Request processing complete")
 
         except Exception as e:
@@ -185,22 +370,18 @@ class Processor(FlowProcessor):
 
             logger.debug("Sending error response...")
 
-            # Send error response with end_of_stream flag if streaming was requested
-            error_response = GraphRagResponse(
-                response = None,
-                error = Error(
-                    type = "graph-rag-error",
-                    message = str(e),
-                ),
-            )
-
-            # If streaming was requested, indicate stream end
-            if v.streaming:
-                error_response.end_of_stream = True
-
+            # Send error response and close session
             await flow("response").send(
-                error_response,
-                properties = {"id": id}
+                GraphRagResponse(
+                    message_type="chunk",
+                    error=Error(
+                        type="graph-rag-error",
+                        message=str(e),
+                    ),
+                    end_of_stream=True,
+                    end_of_session=True,
+                ),
+                properties={"id": id}
             )
 
     @staticmethod
@@ -243,6 +424,12 @@ class Processor(FlowProcessor):
             help=f'Default max path length (default: 2)'
         )
 
+        parser.add_argument(
+            '--explainability-collection',
+            default='explainability',
+            help=f'Collection for storing explainability triples (default: explainability)'
+        )
+
 def run():
 
     Processor.launch(default_ident, __doc__)