Deliver explainability triples inline in retrieval response stream (#763)

Provenance triples are now included directly in explain messages from
GraphRAG, DocumentRAG, and Agent services, eliminating the need for
follow-up knowledge graph queries to retrieve explainability details.

Each explain message in the response stream now carries:
- explain_id: root URI for this provenance step (unchanged)
- explain_graph: named graph where triples are stored (unchanged)
- explain_triples: the actual provenance triples for this step (new)

Changes across the stack:
- Schema: added explain_triples field to GraphRagResponse,
  DocumentRagResponse, and AgentResponse
- Services: all explain message call sites pass triples through
  (graph_rag, document_rag, agent react, agent orchestrator)
- Translators: encode explain_triples via TripleTranslator for
  gateway wire format
- Python SDK: ProvenanceEvent now includes parsed ExplainEntity
  and raw triples; expanded event_type detection
- CLI: invoke_graph_rag, invoke_agent, invoke_document_rag use
  inline entity when available, fall back to graph query
- Tech specs updated

Adds an additional explainability test.
This commit is contained in:
cybermaggedon 2026-04-07 12:19:05 +01:00 committed by GitHub
parent 2f8d6a3ffb
commit ddd4bd7790
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 521 additions and 49 deletions

View file

@ -366,19 +366,13 @@ class SocketClient:
# Handle GraphRAG/DocRAG message format with message_type
if message_type == "explain":
if include_provenance:
return ProvenanceEvent(
explain_id=resp.get("explain_id", ""),
explain_graph=resp.get("explain_graph", "")
)
return self._build_provenance_event(resp)
return None
# Handle Agent message format with chunk_type="explain"
if chunk_type == "explain":
if include_provenance:
return ProvenanceEvent(
explain_id=resp.get("explain_id", ""),
explain_graph=resp.get("explain_graph", "")
)
return self._build_provenance_event(resp)
return None
if chunk_type == "thought":
@ -413,6 +407,42 @@ class SocketClient:
error=None
)
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
    """Construct a ProvenanceEvent from a raw response dict.

    When the response carries inline wire-format triples under
    "explain_triples", they are converted to (subject, predicate, object)
    tuples and used to build an ExplainEntity. Parsing is best-effort:
    any failure leaves entity as None, while the raw triples are still
    attached to the returned event.
    """
    explain_id = resp.get("explain_id", "")
    explain_graph = resp.get("explain_graph", "")
    raw_triples = resp.get("explain_triples", [])

    entity = None
    if raw_triples:
        try:
            from .explainability import ExplainEntity

            def iri(term):
                # IRI terms carry their identifier under "i"; absent or
                # empty terms map to "".
                return term.get("i", "") if term else ""

            def obj(term):
                # Object terms are IRIs when t == "i", literal values
                # (under "v") otherwise.
                if not term:
                    return ""
                if term.get("t") == "i":
                    return term.get("i", "")
                return term.get("v", "")

            parsed = [
                (iri(t.get("s")), iri(t.get("p")), obj(t.get("o", {})))
                for t in raw_triples
            ]
            entity = ExplainEntity.from_triples(explain_id, parsed)
        except Exception:
            # Best-effort: malformed triples (or an unavailable module)
            # must not break the event stream.
            pass

    return ProvenanceEvent(
        explain_id=explain_id,
        explain_graph=explain_graph,
        entity=entity,
        triples=raw_triples,
    )
def close(self) -> None:
"""Close the persistent WebSocket connection."""
if self._loop and not self._loop.is_closed():

View file

@ -213,25 +213,47 @@ class ProvenanceEvent:
"""
Provenance event for explainability.
Emitted during GraphRAG queries when explainable mode is enabled.
Emitted during retrieval queries when explainable mode is enabled.
Each event represents a provenance node created during query processing.
Attributes:
explain_id: URI of the provenance node (e.g., urn:trustgraph:question:abc123)
explain_graph: Named graph where provenance triples are stored (e.g., urn:graph:retrieval)
event_type: Type of provenance event (question, exploration, focus, synthesis)
event_type: Type of provenance event (question, exploration, focus, synthesis, etc.)
entity: Parsed ExplainEntity from inline triples (if available)
triples: Raw triples from the response (wire format dicts)
"""
explain_id: str
explain_graph: str = ""
event_type: str = "" # Derived from explain_id
entity: object = None # ExplainEntity (parsed from triples)
triples: list = dataclasses.field(default_factory=list) # Raw wire-format triple dicts
def __post_init__(self):
    """Derive event_type from the first known keyword found in explain_id.

    The keyword order matches the original elif chain, so an explain_id
    containing several keywords resolves identically; if none match,
    event_type keeps its default value.
    """
    keywords = (
        "question", "grounding", "exploration", "focus", "synthesis",
        "iteration", "observation", "conclusion", "decomposition",
        "finding", "plan", "step-result", "session",
    )
    for keyword in keywords:
        if keyword in self.explain_id:
            self.event_type = keyword
            break

View file

@ -1,6 +1,7 @@
from typing import Dict, Any, Tuple
from ...schema import AgentRequest, AgentResponse
from .base import MessageTranslator
from .primitives import TripleTranslator
class AgentRequestTranslator(MessageTranslator):
@ -49,10 +50,13 @@ class AgentRequestTranslator(MessageTranslator):
class AgentResponseTranslator(MessageTranslator):
"""Translator for AgentResponse schema objects"""
def __init__(self):
"""Create the TripleTranslator used by encode() for explain_triples."""
self.triple_translator = TripleTranslator()
def decode(self, data: Dict[str, Any]) -> AgentResponse:
"""Decoding is unsupported; this translator only encodes responses."""
raise NotImplementedError("Response translation to Pulsar not typically needed")
def encode(self, obj: AgentResponse) -> Dict[str, Any]:
result = {}
@ -75,6 +79,13 @@ class AgentResponseTranslator(MessageTranslator):
if explain_graph is not None:
result["explain_graph"] = explain_graph
# Include explain_triples for explain messages
explain_triples = getattr(obj, "explain_triples", [])
if explain_triples:
result["explain_triples"] = [
self.triple_translator.encode(t) for t in explain_triples
]
# Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "code": obj.error.code}

View file

@ -1,6 +1,7 @@
from typing import Dict, Any, Tuple
from ...schema import DocumentRagQuery, DocumentRagResponse, GraphRagQuery, GraphRagResponse
from .base import MessageTranslator
from .primitives import TripleTranslator
class DocumentRagRequestTranslator(MessageTranslator):
@ -28,6 +29,9 @@ class DocumentRagRequestTranslator(MessageTranslator):
class DocumentRagResponseTranslator(MessageTranslator):
"""Translator for DocumentRagResponse schema objects"""
def __init__(self):
"""Create the TripleTranslator used by encode() for explain_triples."""
self.triple_translator = TripleTranslator()
def decode(self, data: Dict[str, Any]) -> DocumentRagResponse:
"""Decoding is unsupported; this translator only encodes responses."""
raise NotImplementedError("Response translation to Pulsar not typically needed")
@ -53,6 +57,13 @@ class DocumentRagResponseTranslator(MessageTranslator):
if explain_graph is not None:
result["explain_graph"] = explain_graph
# Include explain_triples for explain messages
explain_triples = getattr(obj, "explain_triples", [])
if explain_triples:
result["explain_triples"] = [
self.triple_translator.encode(t) for t in explain_triples
]
# Include end_of_stream flag (LLM stream complete)
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
@ -107,6 +118,9 @@ class GraphRagRequestTranslator(MessageTranslator):
class GraphRagResponseTranslator(MessageTranslator):
"""Translator for GraphRagResponse schema objects"""
def __init__(self):
"""Create the TripleTranslator used by encode() for explain_triples."""
self.triple_translator = TripleTranslator()
def decode(self, data: Dict[str, Any]) -> GraphRagResponse:
"""Decoding is unsupported; this translator only encodes responses."""
raise NotImplementedError("Response translation to Pulsar not typically needed")
@ -132,6 +146,13 @@ class GraphRagResponseTranslator(MessageTranslator):
if explain_graph is not None:
result["explain_graph"] = explain_graph
# Include explain_triples for explain messages
explain_triples = getattr(obj, "explain_triples", [])
if explain_triples:
result["explain_triples"] = [
self.triple_translator.encode(t) for t in explain_triples
]
# Include end_of_stream flag (LLM stream complete)
result["end_of_stream"] = getattr(obj, "end_of_stream", False)

View file

@ -2,7 +2,7 @@
from dataclasses import dataclass, field
from typing import Optional
from ..core.primitives import Error
from ..core.primitives import Error, Triple
############################################################################
@ -57,8 +57,9 @@ class AgentResponse:
end_of_dialog: bool = False # Entire agent dialog is complete
# Explainability fields
explain_id: str | None = None # Provenance URI (announced as created)
explain_graph: str | None = None # Named graph where explain was stored
explain_id: str | None = None # Root URI for this explain step
explain_graph: str | None = None # Named graph (e.g., urn:graph:retrieval)
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
# Orchestration fields
message_id: str = "" # Unique ID for this response message

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass
from ..core.primitives import Error, Term
from dataclasses import dataclass, field
from ..core.primitives import Error, Term, Triple
############################################################################
@ -24,8 +24,9 @@ class GraphRagResponse:
error: Error | None = None
response: str = ""
end_of_stream: bool = False # LLM response stream complete
explain_id: str | None = None # Single explain URI (announced as created)
explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval)
explain_id: str | None = None # Root URI for this explain step
explain_graph: str | None = None # Named graph (e.g., urn:graph:retrieval)
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete
@ -46,7 +47,8 @@ class DocumentRagResponse:
error: Error | None = None
response: str | None = ""
end_of_stream: bool = False # LLM response stream complete
explain_id: str | None = None # Single explain URI (announced as created)
explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval)
explain_id: str | None = None # Root URI for this explain step
explain_graph: str | None = None # Named graph (e.g., urn:graph:retrieval)
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete