mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Deliver explainability triples inline in retrieval response stream (#763)
Provenance triples are now included directly in explain messages from GraphRAG, DocumentRAG, and Agent services, eliminating the need for follow-up knowledge graph queries to retrieve explainability details. Each explain message in the response stream now carries: - explain_id: root URI for this provenance step (unchanged) - explain_graph: named graph where triples are stored (unchanged) - explain_triples: the actual provenance triples for this step (new) Changes across the stack: - Schema: added explain_triples field to GraphRagResponse, DocumentRagResponse, and AgentResponse - Services: all explain message call sites pass triples through (graph_rag, document_rag, agent react, agent orchestrator) - Translators: encode explain_triples via TripleTranslator for gateway wire format - Python SDK: ProvenanceEvent now includes parsed ExplainEntity and raw triples; expanded event_type detection - CLI: invoke_graph_rag, invoke_agent, invoke_document_rag use inline entity when available, fall back to graph query - Tech specs updated Additional explainability test
This commit is contained in:
parent
2f8d6a3ffb
commit
ddd4bd7790
16 changed files with 521 additions and 49 deletions
|
|
@ -366,19 +366,13 @@ class SocketClient:
|
|||
# Handle GraphRAG/DocRAG message format with message_type
|
||||
if message_type == "explain":
|
||||
if include_provenance:
|
||||
return ProvenanceEvent(
|
||||
explain_id=resp.get("explain_id", ""),
|
||||
explain_graph=resp.get("explain_graph", "")
|
||||
)
|
||||
return self._build_provenance_event(resp)
|
||||
return None
|
||||
|
||||
# Handle Agent message format with chunk_type="explain"
|
||||
if chunk_type == "explain":
|
||||
if include_provenance:
|
||||
return ProvenanceEvent(
|
||||
explain_id=resp.get("explain_id", ""),
|
||||
explain_graph=resp.get("explain_graph", "")
|
||||
)
|
||||
return self._build_provenance_event(resp)
|
||||
return None
|
||||
|
||||
if chunk_type == "thought":
|
||||
|
|
@ -413,6 +407,42 @@ class SocketClient:
|
|||
error=None
|
||||
)
|
||||
|
||||
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
|
||||
"""Build a ProvenanceEvent from a response dict, parsing inline triples
|
||||
into an ExplainEntity if available."""
|
||||
explain_id = resp.get("explain_id", "")
|
||||
explain_graph = resp.get("explain_graph", "")
|
||||
raw_triples = resp.get("explain_triples", [])
|
||||
|
||||
entity = None
|
||||
if raw_triples:
|
||||
try:
|
||||
from .explainability import ExplainEntity
|
||||
# Convert wire-format triple dicts to (s, p, o) tuples
|
||||
parsed = []
|
||||
for t in raw_triples:
|
||||
s = t.get("s", {}).get("i", "") if t.get("s") else ""
|
||||
p = t.get("p", {}).get("i", "") if t.get("p") else ""
|
||||
o_term = t.get("o", {})
|
||||
if o_term:
|
||||
if o_term.get("t") == "i":
|
||||
o = o_term.get("i", "")
|
||||
else:
|
||||
o = o_term.get("v", "")
|
||||
else:
|
||||
o = ""
|
||||
parsed.append((s, p, o))
|
||||
entity = ExplainEntity.from_triples(explain_id, parsed)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ProvenanceEvent(
|
||||
explain_id=explain_id,
|
||||
explain_graph=explain_graph,
|
||||
entity=entity,
|
||||
triples=raw_triples,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the persistent WebSocket connection."""
|
||||
if self._loop and not self._loop.is_closed():
|
||||
|
|
|
|||
|
|
@ -213,25 +213,47 @@ class ProvenanceEvent:
|
|||
"""
|
||||
Provenance event for explainability.
|
||||
|
||||
Emitted during GraphRAG queries when explainable mode is enabled.
|
||||
Emitted during retrieval queries when explainable mode is enabled.
|
||||
Each event represents a provenance node created during query processing.
|
||||
|
||||
Attributes:
|
||||
explain_id: URI of the provenance node (e.g., urn:trustgraph:question:abc123)
|
||||
explain_graph: Named graph where provenance triples are stored (e.g., urn:graph:retrieval)
|
||||
event_type: Type of provenance event (question, exploration, focus, synthesis)
|
||||
event_type: Type of provenance event (question, exploration, focus, synthesis, etc.)
|
||||
entity: Parsed ExplainEntity from inline triples (if available)
|
||||
triples: Raw triples from the response (wire format dicts)
|
||||
"""
|
||||
explain_id: str
|
||||
explain_graph: str = ""
|
||||
event_type: str = "" # Derived from explain_id
|
||||
entity: object = None # ExplainEntity (parsed from triples)
|
||||
triples: list = dataclasses.field(default_factory=list) # Raw wire-format triple dicts
|
||||
|
||||
def __post_init__(self):
|
||||
# Extract event type from explain_id
|
||||
if "question" in self.explain_id:
|
||||
self.event_type = "question"
|
||||
elif "grounding" in self.explain_id:
|
||||
self.event_type = "grounding"
|
||||
elif "exploration" in self.explain_id:
|
||||
self.event_type = "exploration"
|
||||
elif "focus" in self.explain_id:
|
||||
self.event_type = "focus"
|
||||
elif "synthesis" in self.explain_id:
|
||||
self.event_type = "synthesis"
|
||||
elif "iteration" in self.explain_id:
|
||||
self.event_type = "iteration"
|
||||
elif "observation" in self.explain_id:
|
||||
self.event_type = "observation"
|
||||
elif "conclusion" in self.explain_id:
|
||||
self.event_type = "conclusion"
|
||||
elif "decomposition" in self.explain_id:
|
||||
self.event_type = "decomposition"
|
||||
elif "finding" in self.explain_id:
|
||||
self.event_type = "finding"
|
||||
elif "plan" in self.explain_id:
|
||||
self.event_type = "plan"
|
||||
elif "step-result" in self.explain_id:
|
||||
self.event_type = "step-result"
|
||||
elif "session" in self.explain_id:
|
||||
self.event_type = "session"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import AgentRequest, AgentResponse
|
||||
from .base import MessageTranslator
|
||||
from .primitives import TripleTranslator
|
||||
|
||||
|
||||
class AgentRequestTranslator(MessageTranslator):
|
||||
|
|
@ -49,10 +50,13 @@ class AgentRequestTranslator(MessageTranslator):
|
|||
|
||||
class AgentResponseTranslator(MessageTranslator):
|
||||
"""Translator for AgentResponse schema objects"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.triple_translator = TripleTranslator()
|
||||
|
||||
def decode(self, data: Dict[str, Any]) -> AgentResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
|
||||
def encode(self, obj: AgentResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
|
|
@ -75,6 +79,13 @@ class AgentResponseTranslator(MessageTranslator):
|
|||
if explain_graph is not None:
|
||||
result["explain_graph"] = explain_graph
|
||||
|
||||
# Include explain_triples for explain messages
|
||||
explain_triples = getattr(obj, "explain_triples", [])
|
||||
if explain_triples:
|
||||
result["explain_triples"] = [
|
||||
self.triple_translator.encode(t) for t in explain_triples
|
||||
]
|
||||
|
||||
# Always include error if present
|
||||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "code": obj.error.code}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import DocumentRagQuery, DocumentRagResponse, GraphRagQuery, GraphRagResponse
|
||||
from .base import MessageTranslator
|
||||
from .primitives import TripleTranslator
|
||||
|
||||
|
||||
class DocumentRagRequestTranslator(MessageTranslator):
|
||||
|
|
@ -28,6 +29,9 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
|||
class DocumentRagResponseTranslator(MessageTranslator):
|
||||
"""Translator for DocumentRagResponse schema objects"""
|
||||
|
||||
def __init__(self):
|
||||
self.triple_translator = TripleTranslator()
|
||||
|
||||
def decode(self, data: Dict[str, Any]) -> DocumentRagResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
|
|
@ -53,6 +57,13 @@ class DocumentRagResponseTranslator(MessageTranslator):
|
|||
if explain_graph is not None:
|
||||
result["explain_graph"] = explain_graph
|
||||
|
||||
# Include explain_triples for explain messages
|
||||
explain_triples = getattr(obj, "explain_triples", [])
|
||||
if explain_triples:
|
||||
result["explain_triples"] = [
|
||||
self.triple_translator.encode(t) for t in explain_triples
|
||||
]
|
||||
|
||||
# Include end_of_stream flag (LLM stream complete)
|
||||
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||
|
||||
|
|
@ -107,6 +118,9 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
class GraphRagResponseTranslator(MessageTranslator):
|
||||
"""Translator for GraphRagResponse schema objects"""
|
||||
|
||||
def __init__(self):
|
||||
self.triple_translator = TripleTranslator()
|
||||
|
||||
def decode(self, data: Dict[str, Any]) -> GraphRagResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
|
|
@ -132,6 +146,13 @@ class GraphRagResponseTranslator(MessageTranslator):
|
|||
if explain_graph is not None:
|
||||
result["explain_graph"] = explain_graph
|
||||
|
||||
# Include explain_triples for explain messages
|
||||
explain_triples = getattr(obj, "explain_triples", [])
|
||||
if explain_triples:
|
||||
result["explain_triples"] = [
|
||||
self.triple_translator.encode(t) for t in explain_triples
|
||||
]
|
||||
|
||||
# Include end_of_stream flag (LLM stream complete)
|
||||
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from ..core.primitives import Error
|
||||
from ..core.primitives import Error, Triple
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
@ -57,8 +57,9 @@ class AgentResponse:
|
|||
end_of_dialog: bool = False # Entire agent dialog is complete
|
||||
|
||||
# Explainability fields
|
||||
explain_id: str | None = None # Provenance URI (announced as created)
|
||||
explain_graph: str | None = None # Named graph where explain was stored
|
||||
explain_id: str | None = None # Root URI for this explain step
|
||||
explain_graph: str | None = None # Named graph (e.g., urn:graph:retrieval)
|
||||
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
|
||||
|
||||
# Orchestration fields
|
||||
message_id: str = "" # Unique ID for this response message
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from ..core.primitives import Error, Term
|
||||
from dataclasses import dataclass, field
|
||||
from ..core.primitives import Error, Term, Triple
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
@ -24,8 +24,9 @@ class GraphRagResponse:
|
|||
error: Error | None = None
|
||||
response: str = ""
|
||||
end_of_stream: bool = False # LLM response stream complete
|
||||
explain_id: str | None = None # Single explain URI (announced as created)
|
||||
explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval)
|
||||
explain_id: str | None = None # Root URI for this explain step
|
||||
explain_graph: str | None = None # Named graph (e.g., urn:graph:retrieval)
|
||||
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
|
||||
message_type: str = "" # "chunk" or "explain"
|
||||
end_of_session: bool = False # Entire session complete
|
||||
|
||||
|
|
@ -46,7 +47,8 @@ class DocumentRagResponse:
|
|||
error: Error | None = None
|
||||
response: str | None = ""
|
||||
end_of_stream: bool = False # LLM response stream complete
|
||||
explain_id: str | None = None # Single explain URI (announced as created)
|
||||
explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval)
|
||||
explain_id: str | None = None # Root URI for this explain step
|
||||
explain_graph: str | None = None # Named graph (e.g., urn:graph:retrieval)
|
||||
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
|
||||
message_type: str = "" # "chunk" or "explain"
|
||||
end_of_session: bool = False # Entire session complete
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue