mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-25 15:25:12 +02:00
Adding explainability to the ReACT agent (#689)
* Added tech spec
* Add provenance recording to React agent loop
Enables agent sessions to be traced and debugged using the same
explainability infrastructure as GraphRAG. Agent traces record:
- Session start with query and timestamp
- Each iteration's thought, action, arguments, and observation
- Final answer with derivation chain
Changes:
- Add session_id and collection fields to AgentRequest schema
- Add agent predicates (TG_THOUGHT, TG_ACTION, etc.) to namespaces
- Create agent provenance triple generators in provenance/agent.py
- Register explainability producer in agent service
- Emit provenance triples during agent execution
- Update CLI tools to detect and render agent traces alongside GraphRAG
* Updated explainability taxonomy:
GraphRAG: tg:Question → tg:Exploration → tg:Focus → tg:Synthesis
Agent: tg:Question → tg:Analysis(s) → tg:Conclusion
All entities also have their PROV-O type (prov:Activity or prov:Entity).
Updated commit message:
Add provenance recording to React agent loop
Enables agent sessions to be traced and debugged using the same
explainability infrastructure as GraphRAG.
Entity types follow human reasoning patterns:
- tg:Question - the user's query (shared with GraphRAG)
- tg:Analysis - each think/act/observe cycle
- tg:Conclusion - the final answer
Also adds explicit TG types to GraphRAG entities:
- tg:Question, tg:Exploration, tg:Focus, tg:Synthesis
All types retain their PROV-O base types (prov:Activity, prov:Entity).
Changes:
- Add session_id and collection fields to AgentRequest schema
- Add explainability entity types to namespaces.py
- Create agent provenance triple generators
- Register explainability producer in agent service
- Emit provenance triples during agent execution
- Update CLI tools to detect and render both trace types
* Document RAG explainability is now complete. Here's a summary of the
changes made:
Schema Changes:
- trustgraph-base/trustgraph/schema/services/retrieval.py: Added
explain_id and explain_graph fields to DocumentRagResponse
- trustgraph-base/trustgraph/messaging/translators/retrieval.py:
Updated translator to handle explainability fields
Provenance Changes:
- trustgraph-base/trustgraph/provenance/namespaces.py: Added
TG_CHUNK_COUNT and TG_SELECTED_CHUNK predicates
- trustgraph-base/trustgraph/provenance/uris.py: Added
docrag_question_uri, docrag_exploration_uri, docrag_synthesis_uri
generators
- trustgraph-base/trustgraph/provenance/triples.py: Added
docrag_question_triples, docrag_exploration_triples,
docrag_synthesis_triples builders
- trustgraph-base/trustgraph/provenance/__init__.py: Exported all
new Document RAG functions and predicates
Service Changes:
- trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py:
Added explainability callback support and triple emission at each
phase (Question → Exploration → Synthesis)
- trustgraph-flow/trustgraph/retrieval/document_rag/rag.py:
Registered explainability producer and wired up the callback
Documentation:
- docs/tech-specs/agent-explainability.md: Added Document RAG entity
types and provenance model documentation
Document RAG Provenance Model:
Question (urn:trustgraph:docrag:{uuid})
│
│ tg:query, prov:startedAtTime
│ rdf:type = prov:Activity, tg:Question
│
↓ prov:wasGeneratedBy
│
Exploration (urn:trustgraph:docrag:{uuid}/exploration)
│
│ tg:chunkCount, tg:selectedChunk (multiple)
│ rdf:type = prov:Entity, tg:Exploration
│
↓ prov:wasDerivedFrom
│
Synthesis (urn:trustgraph:docrag:{uuid}/synthesis)
│
│ tg:content = "The answer..."
│ rdf:type = prov:Entity, tg:Synthesis
* Specific subtype that makes the retrieval mechanism immediately
obvious:
System: GraphRAG
TG Types on Question: tg:Question, tg:GraphRagQuestion
URI Pattern: urn:trustgraph:question:{uuid}
────────────────────────────────────────
System: Document RAG
TG Types on Question: tg:Question, tg:DocRagQuestion
URI Pattern: urn:trustgraph:docrag:{uuid}
────────────────────────────────────────
System: Agent
TG Types on Question: tg:Question, tg:AgentQuestion
URI Pattern: urn:trustgraph:agent:{uuid}
Files modified:
- trustgraph-base/trustgraph/provenance/namespaces.py - Added
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION
- trustgraph-base/trustgraph/provenance/triples.py - Added subtype to
question_triples and docrag_question_triples
- trustgraph-base/trustgraph/provenance/agent.py - Added subtype to
agent_session_triples
- trustgraph-base/trustgraph/provenance/__init__.py - Exported new types
- docs/tech-specs/agent-explainability.md - Documented the subtypes
This allows:
- Query all questions: ?q rdf:type tg:Question
- Query only GraphRAG: ?q rdf:type tg:GraphRagQuestion
- Query only Document RAG: ?q rdf:type tg:DocRagQuestion
- Query only Agent: ?q rdf:type tg:AgentQuestion
* Fixed tests
This commit is contained in:
parent
a53ed41da2
commit
312174eb88
17 changed files with 1269 additions and 44 deletions
|
|
@ -7,6 +7,8 @@ import re
|
|||
import sys
|
||||
import functools
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -14,8 +16,22 @@ logger = logging.getLogger(__name__)
|
|||
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
||||
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
|
||||
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
|
||||
from ... base import ProducerSpec
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from ... schema import Triples, Metadata
|
||||
|
||||
# Provenance imports for agent explainability
|
||||
from trustgraph.provenance import (
|
||||
agent_session_uri,
|
||||
agent_iteration_uri,
|
||||
agent_final_uri,
|
||||
agent_session_triples,
|
||||
agent_iteration_triples,
|
||||
agent_final_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl, ToolServiceImpl
|
||||
from . agent_manager import AgentManager
|
||||
|
|
@ -105,6 +121,14 @@ class Processor(AgentService):
|
|||
)
|
||||
)
|
||||
|
||||
# Explainability producer for agent provenance triples
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "explainability",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
async def on_tools_config(self, config, version):
|
||||
|
||||
logger.info(f"Loading configuration version {version}")
|
||||
|
|
@ -285,6 +309,10 @@ class Processor(AgentService):
|
|||
# Check if streaming is enabled
|
||||
streaming = getattr(request, 'streaming', False)
|
||||
|
||||
# Generate or retrieve session ID for provenance tracking
|
||||
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
|
||||
collection = getattr(request, 'collection', 'default')
|
||||
|
||||
if request.history:
|
||||
history = [
|
||||
Action(
|
||||
|
|
@ -298,6 +326,27 @@ class Processor(AgentService):
|
|||
else:
|
||||
history = []
|
||||
|
||||
# Calculate iteration number (1-based)
|
||||
iteration_num = len(history) + 1
|
||||
session_uri = agent_session_uri(session_id)
|
||||
|
||||
# On first iteration, emit session triples
|
||||
if iteration_num == 1:
|
||||
timestamp = datetime.utcnow().isoformat() + "Z"
|
||||
triples = set_graph(
|
||||
agent_session_triples(session_uri, request.question, timestamp),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=session_uri,
|
||||
user=request.user,
|
||||
collection=collection,
|
||||
),
|
||||
triples=triples,
|
||||
))
|
||||
logger.debug(f"Emitted session triples for {session_uri}")
|
||||
|
||||
logger.info(f"Question: {request.question}")
|
||||
|
||||
if len(history) >= self.max_iterations:
|
||||
|
|
@ -447,6 +496,28 @@ class Processor(AgentService):
|
|||
else:
|
||||
f = json.dumps(act.final)
|
||||
|
||||
# Emit final answer provenance triples
|
||||
final_uri = agent_final_uri(session_id)
|
||||
# Parent is last iteration, or session if no iterations
|
||||
if iteration_num > 1:
|
||||
parent_uri = agent_iteration_uri(session_id, iteration_num - 1)
|
||||
else:
|
||||
parent_uri = session_uri
|
||||
|
||||
final_triples = set_graph(
|
||||
agent_final_triples(final_uri, parent_uri, f),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=final_uri,
|
||||
user=request.user,
|
||||
collection=collection,
|
||||
),
|
||||
triples=final_triples,
|
||||
))
|
||||
logger.debug(f"Emitted final triples for {final_uri}")
|
||||
|
||||
if streaming:
|
||||
# Streaming format - send end-of-dialog marker
|
||||
# Answer chunks were already sent via answer() callback during parsing
|
||||
|
|
@ -479,8 +550,37 @@ class Processor(AgentService):
|
|||
|
||||
logger.debug("Send next...")
|
||||
|
||||
# Emit iteration provenance triples
|
||||
iteration_uri = agent_iteration_uri(session_id, iteration_num)
|
||||
# Parent is previous iteration, or session if this is first iteration
|
||||
if iteration_num > 1:
|
||||
parent_uri = agent_iteration_uri(session_id, iteration_num - 1)
|
||||
else:
|
||||
parent_uri = session_uri
|
||||
|
||||
iter_triples = set_graph(
|
||||
agent_iteration_triples(
|
||||
iteration_uri,
|
||||
parent_uri,
|
||||
act.thought,
|
||||
act.name,
|
||||
act.arguments,
|
||||
act.observation,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=iteration_uri,
|
||||
user=request.user,
|
||||
collection=collection,
|
||||
),
|
||||
triples=iter_triples,
|
||||
))
|
||||
logger.debug(f"Emitted iteration triples for {iteration_uri}")
|
||||
|
||||
history.append(act)
|
||||
|
||||
|
||||
# Handle state transitions if tool execution was successful
|
||||
next_state = request.state
|
||||
if act.name in filtered_tools:
|
||||
|
|
@ -501,7 +601,9 @@ class Processor(AgentService):
|
|||
for h in history
|
||||
],
|
||||
user=request.user,
|
||||
collection=collection,
|
||||
streaming=streaming,
|
||||
session_id=session_id, # Pass session_id for provenance continuity
|
||||
)
|
||||
|
||||
await next(r)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,20 @@
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Provenance imports
|
||||
from trustgraph.provenance import (
|
||||
docrag_question_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_synthesis_uri,
|
||||
docrag_question_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_synthesis_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -33,7 +47,14 @@ class Query:
|
|||
return qembeds[0] if qembeds else []
|
||||
|
||||
async def get_docs(self, query):
|
||||
"""
|
||||
Get documents (chunks) matching the query.
|
||||
|
||||
Returns:
|
||||
tuple: (docs, chunk_ids) where:
|
||||
- docs: list of document content strings
|
||||
- chunk_ids: list of chunk IDs that were successfully fetched
|
||||
"""
|
||||
vectors = await self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
|
|
@ -50,11 +71,13 @@ class Query:
|
|||
|
||||
# Fetch chunk content from Garage
|
||||
docs = []
|
||||
chunk_ids = []
|
||||
for match in chunk_matches:
|
||||
if match.chunk_id:
|
||||
try:
|
||||
content = await self.rag.fetch_chunk(match.chunk_id, self.user)
|
||||
docs.append(content)
|
||||
chunk_ids.append(match.chunk_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")
|
||||
|
||||
|
|
@ -63,7 +86,7 @@ class Query:
|
|||
for doc in docs:
|
||||
logger.debug(f" {doc[:100]}...")
|
||||
|
||||
return docs
|
||||
return docs, chunk_ids
|
||||
|
||||
class DocumentRag:
|
||||
|
||||
|
|
@ -86,17 +109,56 @@ class DocumentRag:
|
|||
async def query(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
doc_limit=20, streaming=False, chunk_callback=None,
|
||||
explain_callback=None,
|
||||
):
|
||||
"""
|
||||
Execute a Document RAG query with optional explainability tracking.
|
||||
|
||||
Args:
|
||||
query: The query string
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
doc_limit: Max chunks to retrieve
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for explainability
|
||||
|
||||
Returns:
|
||||
str: The synthesized answer text
|
||||
"""
|
||||
if self.verbose:
|
||||
logger.debug("Constructing prompt...")
|
||||
|
||||
# Generate explainability URIs upfront
|
||||
session_id = str(uuid.uuid4())
|
||||
q_uri = docrag_question_uri(session_id)
|
||||
exp_uri = docrag_exploration_uri(session_id)
|
||||
syn_uri = docrag_synthesis_uri(session_id)
|
||||
|
||||
timestamp = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
# Emit question explainability immediately
|
||||
if explain_callback:
|
||||
q_triples = set_graph(
|
||||
docrag_question_triples(q_uri, query, timestamp),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(q_triples, q_uri)
|
||||
|
||||
q = Query(
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
||||
docs = await q.get_docs(query)
|
||||
docs, chunk_ids = await q.get_docs(query)
|
||||
|
||||
# Emit exploration explainability after chunks retrieved
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
docrag_exploration_triples(exp_uri, q_uri, len(chunk_ids), chunk_ids),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(exp_triples, exp_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
|
|
@ -104,12 +166,21 @@ class DocumentRag:
|
|||
logger.debug(f"Query: {query}")
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Accumulate chunks for answer storage while forwarding to callback
|
||||
accumulated_chunks = []
|
||||
|
||||
async def accumulating_callback(chunk, end_of_stream):
|
||||
accumulated_chunks.append(chunk)
|
||||
await chunk_callback(chunk, end_of_stream)
|
||||
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
query=query,
|
||||
documents=docs,
|
||||
streaming=True,
|
||||
chunk_callback=chunk_callback
|
||||
chunk_callback=accumulating_callback
|
||||
)
|
||||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
query=query,
|
||||
|
|
@ -119,5 +190,17 @@ class DocumentRag:
|
|||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
||||
# Emit synthesis explainability after answer generated
|
||||
if explain_callback:
|
||||
answer_text = resp if resp else ""
|
||||
syn_triples = set_graph(
|
||||
docrag_synthesis_triples(syn_uri, exp_uri, answer_text),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(syn_triples, syn_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Emitted explain for session {session_id}")
|
||||
|
||||
return resp
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ import logging
|
|||
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
|
||||
from ... schema import LibrarianRequest, LibrarianResponse
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from ... schema import Triples, Metadata
|
||||
from ... provenance import GRAPH_RETRIEVAL
|
||||
from . document_rag import DocumentRag
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
|
|
@ -78,6 +80,13 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "explainability",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
# Librarian client for fetching chunk content from Garage
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
|
|
@ -194,6 +203,29 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
doc_limit = self.doc_limit
|
||||
|
||||
# Real-time explainability callback - emits triples and IDs as they're generated
|
||||
# Triples are stored in the user's collection with a named graph (urn:graph:retrieval)
|
||||
async def send_explainability(triples, explain_id):
|
||||
# Send triples to explainability queue - stores in same collection with named graph
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=explain_id,
|
||||
user=v.user,
|
||||
collection=v.collection, # Store in user's collection
|
||||
),
|
||||
triples=triples,
|
||||
))
|
||||
|
||||
# Send explain ID and graph to response queue
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response=None,
|
||||
explain_id=explain_id,
|
||||
explain_graph=GRAPH_RETRIEVAL,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
# Check if streaming is requested
|
||||
if v.streaming:
|
||||
# Define async callback for streaming chunks
|
||||
|
|
@ -217,6 +249,7 @@ class Processor(FlowProcessor):
|
|||
doc_limit=doc_limit,
|
||||
streaming=True,
|
||||
chunk_callback=send_chunk,
|
||||
explain_callback=send_explainability,
|
||||
)
|
||||
else:
|
||||
# Non-streaming path (existing behavior)
|
||||
|
|
@ -224,7 +257,8 @@ class Processor(FlowProcessor):
|
|||
v.query,
|
||||
user=v.user,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit
|
||||
doc_limit=doc_limit,
|
||||
explain_callback=send_explainability,
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue