Add unified explainability support and librarian storage for all retrieval engines (#693)

Add unified explainability support and librarian storage for all retrieval engines

Implements consistent explainability/provenance tracking
across GraphRAG, DocumentRAG, and Agent retrieval
engines. All large content (answers, thoughts, observations)
is now stored in librarian rather than as inline literals in
the knowledge graph.

Explainability API:
- New explainability.py module with entity classes (Question,
  Exploration, Focus, Synthesis, Analysis, Conclusion) and
  ExplainabilityClient
- Quiescence-based eventual consistency handling for trace
  fetching
- Content fetching from librarian with retry logic

CLI updates:
- tg-invoke-graph-rag -x/--explainable flag returns
  explain_id
- tg-invoke-document-rag -x/--explainable flag returns
  explain_id
- tg-invoke-agent -x/--explainable flag returns explain_id
- tg-list-explain-traces uses new explainability API
- tg-show-explain-trace handles all three trace types

Agent provenance:
- Records session, iterations (think/act/observe), and conclusion
- Stores thoughts and observations in librarian with document
  references
- New predicates: tg:thoughtDocument, tg:observationDocument

DocumentRAG provenance:
- Records question, exploration (chunk retrieval), and synthesis
- Stores answers in librarian with document references

Schema changes:
- AgentResponse: added explain_id, explain_graph fields
- RetrievalResponse: added explain_id, explain_graph fields
- agent_iteration_triples: supports thought_document_id,
  observation_document_id

Update tests.
This commit is contained in:
cybermaggedon 2026-03-12 21:40:09 +00:00 committed by GitHub
parent aecf00f040
commit 35128ff019
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 2736 additions and 846 deletions

View file

@ -4,8 +4,19 @@ Uses the agent service to answer a question
import argparse
import os
import sys
import textwrap
from trustgraph.api import Api
from trustgraph.api import (
Api,
ExplainabilityClient,
ProvenanceEvent,
Question,
Analysis,
Conclusion,
AgentThought,
AgentObservation,
AgentAnswer,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -97,11 +108,148 @@ def output(text, prefix="> ", width=78):
)
print(out)
def question_explainable(
        url, question_text, flow_id, user, collection,
        state=None, group=None, verbose=False, token=None, debug=False
):
    """Execute agent with explainability - shows provenance events inline.

    Streams the agent over a websocket flow, printing thought/observation
    chunks (verbose mode only) and answer chunks as they arrive, and
    resolves each ProvenanceEvent into its explainability entity
    (Question, Analysis, Conclusion) which is summarised on stderr.

    Args:
        url: TrustGraph API base URL.
        question_text: The question to put to the agent.
        flow_id: Flow identifier to invoke.
        user: User identifier.
        collection: Collection identifier.
        state: Optional agent state passed through to agent_explain.
        group: Optional agent group passed through to agent_explain.
        verbose: When True, thought/observation chunks are rendered.
        token: Optional auth token.
        debug: When True, fetch failures and unknown entity types are
            reported on stderr instead of being silently skipped.

    NOTE(review): indentation reconstructed from a diff rendering with
    stripped whitespace - confirm nesting against the original file.
    """
    api = Api(url=url, token=token)
    socket = api.socket()
    flow = socket.flow(flow_id)
    # Retry settings give eventually-consistent trace data time to land
    explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
    try:
        # Track last chunk type for formatting
        last_chunk_type = None
        current_outputter = None
        # Stream agent with explainability - process events as they arrive
        for item in flow.agent_explain(
            question=question_text,
            user=user,
            collection=collection,
            state=state,
            group=group,
        ):
            if isinstance(item, AgentThought):
                # On a type transition, close the previous outputter and
                # (in verbose mode) open a fresh one for thought text
                if last_chunk_type != "thought":
                    if current_outputter:
                        current_outputter.__exit__(None, None, None)
                        current_outputter = None
                    print()  # Blank line between message types
                    if verbose:
                        current_outputter = Outputter(width=78, prefix="\U0001f914 ")
                        current_outputter.__enter__()
                    last_chunk_type = "thought"
                if current_outputter:
                    current_outputter.output(item.content)
                    # Flush any buffered partial word so output appears live
                    if current_outputter.word_buffer:
                        print(current_outputter.word_buffer, end="", flush=True)
                        current_outputter.column += len(current_outputter.word_buffer)
                        current_outputter.word_buffer = ""
            elif isinstance(item, AgentObservation):
                if last_chunk_type != "observation":
                    if current_outputter:
                        current_outputter.__exit__(None, None, None)
                        current_outputter = None
                    print()
                    if verbose:
                        current_outputter = Outputter(width=78, prefix="\U0001f4a1 ")
                        current_outputter.__enter__()
                    last_chunk_type = "observation"
                if current_outputter:
                    current_outputter.output(item.content)
                    # Flush any buffered partial word so output appears live
                    if current_outputter.word_buffer:
                        print(current_outputter.word_buffer, end="", flush=True)
                        current_outputter.column += len(current_outputter.word_buffer)
                        current_outputter.word_buffer = ""
            elif isinstance(item, AgentAnswer):
                if last_chunk_type != "answer":
                    if current_outputter:
                        current_outputter.__exit__(None, None, None)
                        current_outputter = None
                    print()
                    last_chunk_type = "answer"
                # Print answer content directly
                print(item.content, end="", flush=True)
            elif isinstance(item, ProvenanceEvent):
                # Process provenance event immediately
                prov_id = item.explain_id
                explain_graph = item.explain_graph or "urn:graph:retrieval"
                entity = explain_client.fetch_entity(
                    prov_id,
                    graph=explain_graph,
                    user=user,
                    collection=collection
                )
                if entity is None:
                    if debug:
                        print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
                    continue
                # Display based on entity type
                if isinstance(entity, Question):
                    print(f"\n [session] {prov_id}", file=sys.stderr)
                    if entity.query:
                        print(f" Query: {entity.query}", file=sys.stderr)
                    if entity.timestamp:
                        print(f" Time: {entity.timestamp}", file=sys.stderr)
                elif isinstance(entity, Analysis):
                    print(f"\n [iteration] {prov_id}", file=sys.stderr)
                    if entity.thought:
                        # Truncate long thoughts for the inline summary
                        thought_short = entity.thought[:80] + "..." if len(entity.thought) > 80 else entity.thought
                        print(f" Thought: {thought_short}", file=sys.stderr)
                    if entity.action:
                        print(f" Action: {entity.action}", file=sys.stderr)
                elif isinstance(entity, Conclusion):
                    print(f"\n [conclusion] {prov_id}", file=sys.stderr)
                    if entity.answer:
                        print(f" Answer length: {len(entity.answer)} chars", file=sys.stderr)
                else:
                    if debug:
                        print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
        # Close any remaining outputter
        if current_outputter:
            current_outputter.__exit__(None, None, None)
            current_outputter = None
        # Final newline if we ended with answer
        if last_chunk_type == "answer":
            print()
    finally:
        socket.close()
def question(
url, question, flow_id, user, collection,
plan=None, state=None, group=None, verbose=False, streaming=True,
token=None
token=None, explainable=False, debug=False
):
# Explainable mode uses the API to capture and process provenance events
if explainable:
question_explainable(
url=url,
question_text=question,
flow_id=flow_id,
user=user,
collection=collection,
state=state,
group=group,
verbose=verbose,
token=token,
debug=debug
)
return
if verbose:
output(wrap(question), "\U00002753 ")
@ -270,6 +418,18 @@ def main():
help=f'Disable streaming (use legacy mode)'
)
parser.add_argument(
'-x', '--explainable',
action='store_true',
help='Show provenance events: Session, Iterations, Conclusion (implies streaming)'
)
parser.add_argument(
'--debug',
action='store_true',
help='Show debug output for troubleshooting'
)
args = parser.parse_args()
try:
@ -286,6 +446,8 @@ def main():
verbose = args.verbose,
streaming = not args.no_streaming,
token = args.token,
explainable = args.explainable,
debug = args.debug,
)
except Exception as e:

View file

@ -4,7 +4,16 @@ Uses the DocumentRAG service to answer a question
import argparse
import os
from trustgraph.api import Api
import sys
from trustgraph.api import (
Api,
ExplainabilityClient,
RAGChunk,
ProvenanceEvent,
Question,
Exploration,
Synthesis,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -12,7 +21,90 @@ default_user = 'trustgraph'
default_collection = 'default'
default_doc_limit = 10
def question(url, flow_id, question, user, collection, doc_limit, streaming=True, token=None):
def question_explainable(
        url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False
):
    """Execute document RAG with explainability - shows provenance events inline.

    Streams the DocumentRAG flow, printing answer chunks as they arrive and
    resolving each ProvenanceEvent into its explainability entity
    (Question, Exploration, Synthesis), summarised on stderr.

    Args:
        url: TrustGraph API base URL.
        flow_id: Flow identifier to invoke.
        question_text: The query to answer.
        user: User identifier.
        collection: Collection identifier.
        doc_limit: Maximum number of documents to retrieve.
        token: Optional auth token.
        debug: When True, fetch failures and unknown entity types are
            reported on stderr instead of being silently skipped.

    NOTE(review): indentation reconstructed from a diff rendering with
    stripped whitespace - confirm nesting against the original file.
    """
    api = Api(url=url, token=token)
    socket = api.socket()
    flow = socket.flow(flow_id)
    # Retry settings give eventually-consistent trace data time to land
    explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
    try:
        # Stream DocumentRAG with explainability - process events as they arrive
        for item in flow.document_rag_explain(
            query=question_text,
            user=user,
            collection=collection,
            doc_limit=doc_limit,
        ):
            if isinstance(item, RAGChunk):
                # Print response content
                print(item.content, end="", flush=True)
            elif isinstance(item, ProvenanceEvent):
                # Process provenance event immediately
                prov_id = item.explain_id
                explain_graph = item.explain_graph or "urn:graph:retrieval"
                entity = explain_client.fetch_entity(
                    prov_id,
                    graph=explain_graph,
                    user=user,
                    collection=collection
                )
                if entity is None:
                    if debug:
                        print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
                    continue
                # Display based on entity type
                if isinstance(entity, Question):
                    print(f"\n [question] {prov_id}", file=sys.stderr)
                    if entity.query:
                        print(f" Query: {entity.query}", file=sys.stderr)
                    if entity.timestamp:
                        print(f" Time: {entity.timestamp}", file=sys.stderr)
                elif isinstance(entity, Exploration):
                    print(f"\n [exploration] {prov_id}", file=sys.stderr)
                    if entity.chunk_count:
                        print(f" Chunks retrieved: {entity.chunk_count}", file=sys.stderr)
                elif isinstance(entity, Synthesis):
                    print(f"\n [synthesis] {prov_id}", file=sys.stderr)
                    if entity.content:
                        print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
                else:
                    if debug:
                        print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
        print()  # Final newline
    finally:
        socket.close()
def question(
url, flow_id, question_text, user, collection, doc_limit,
streaming=True, token=None, explainable=False, debug=False
):
# Explainable mode uses the API to capture and process provenance events
if explainable:
question_explainable(
url=url,
flow_id=flow_id,
question_text=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
token=token,
debug=debug
)
return
# Create API client
api = Api(url=url, token=token)
@ -24,7 +116,7 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True
try:
response = flow.document_rag(
query=question,
query=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
@ -42,13 +134,14 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True
# Use REST API for non-streaming
flow = api.flow().id(flow_id)
resp = flow.document_rag(
query=question,
query=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
)
print(resp)
def main():
parser = argparse.ArgumentParser(
@ -105,6 +198,18 @@ def main():
help='Disable streaming (use non-streaming mode)'
)
parser.add_argument(
'-x', '--explainable',
action='store_true',
help='Show provenance events: Question, Exploration, Synthesis (implies streaming)'
)
parser.add_argument(
'--debug',
action='store_true',
help='Show debug output for troubleshooting'
)
args = parser.parse_args()
try:
@ -112,12 +217,14 @@ def main():
question(
url=args.url,
flow_id=args.flow_id,
question=args.question,
question_text=args.question,
user=args.user,
collection=args.collection,
doc_limit=args.doc_limit,
streaming=not args.no_streaming,
token=args.token,
explainable=args.explainable,
debug=args.debug,
)
except Exception as e:

View file

@ -8,7 +8,16 @@ import os
import sys
import websockets
import asyncio
from trustgraph.api import Api
from trustgraph.api import (
Api,
ExplainabilityClient,
RAGChunk,
ProvenanceEvent,
Question,
Exploration,
Focus,
Synthesis,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -602,18 +611,111 @@ async def _question_explainable(
print() # Final newline
def _question_explainable_api(
        url, flow_id, question_text, user, collection, entity_limit, triple_limit,
        max_subgraph_size, max_path_length, token=None, debug=False
):
    """Execute graph RAG with explainability using the new API classes.

    Streams the GraphRAG flow, printing answer chunks as they arrive and
    resolving each ProvenanceEvent into its explainability entity
    (Question, Exploration, Focus, Synthesis), summarised on stderr.
    Focus entities are additionally expanded to show the selected edges
    and the reasoning behind each selection.

    Args:
        url: TrustGraph API base URL.
        flow_id: Flow identifier to invoke.
        question_text: The query to answer.
        user: User identifier.
        collection: Collection identifier.
        entity_limit: Accepted for signature compatibility - not used by
            the visible body (TODO confirm against caller).
        triple_limit: Accepted for signature compatibility - not used by
            the visible body (TODO confirm against caller).
        max_subgraph_size: Passed to graph_rag_explain.
        max_path_length: Passed as max_entity_distance.
        token: Optional auth token.
        debug: When True, fetch failures and unknown entity types are
            reported on stderr instead of being silently skipped.

    NOTE(review): indentation reconstructed from a diff rendering with
    stripped whitespace - confirm nesting against the original file.
    """
    api = Api(url=url, token=token)
    socket = api.socket()
    flow = socket.flow(flow_id)
    # Retry settings give eventually-consistent trace data time to land
    explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
    try:
        # Stream GraphRAG with explainability - process events as they arrive
        for item in flow.graph_rag_explain(
            query=question_text,
            user=user,
            collection=collection,
            max_subgraph_size=max_subgraph_size,
            max_subgraph_count=5,
            max_entity_distance=max_path_length,
        ):
            if isinstance(item, RAGChunk):
                # Print response content
                print(item.content, end="", flush=True)
            elif isinstance(item, ProvenanceEvent):
                # Process provenance event immediately
                prov_id = item.explain_id
                explain_graph = item.explain_graph or "urn:graph:retrieval"
                entity = explain_client.fetch_entity(
                    prov_id,
                    graph=explain_graph,
                    user=user,
                    collection=collection
                )
                if entity is None:
                    if debug:
                        print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
                    continue
                # Display based on entity type
                if isinstance(entity, Question):
                    print(f"\n [question] {prov_id}", file=sys.stderr)
                    if entity.query:
                        print(f" Query: {entity.query}", file=sys.stderr)
                    if entity.timestamp:
                        print(f" Time: {entity.timestamp}", file=sys.stderr)
                elif isinstance(entity, Exploration):
                    print(f"\n [exploration] {prov_id}", file=sys.stderr)
                    if entity.edge_count:
                        print(f" Edges explored: {entity.edge_count}", file=sys.stderr)
                elif isinstance(entity, Focus):
                    print(f"\n [focus] {prov_id}", file=sys.stderr)
                    if entity.selected_edge_uris:
                        print(f" Focused on {len(entity.selected_edge_uris)} edge(s)", file=sys.stderr)
                    # Fetch full focus with edge details
                    focus_full = explain_client.fetch_focus_with_edges(
                        prov_id,
                        graph=explain_graph,
                        user=user,
                        collection=collection
                    )
                    if focus_full and focus_full.edge_selections:
                        for edge_sel in focus_full.edge_selections:
                            if edge_sel.edge:
                                # Resolve labels for edge components
                                s_label, p_label, o_label = explain_client.resolve_edge_labels(
                                    edge_sel.edge, user, collection
                                )
                                print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
                            if edge_sel.reasoning:
                                # Truncate long reasoning for the inline summary
                                r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
                                print(f" Reason: {r_short}", file=sys.stderr)
                elif isinstance(entity, Synthesis):
                    print(f"\n [synthesis] {prov_id}", file=sys.stderr)
                    if entity.content:
                        print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
                else:
                    if debug:
                        print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
        print()  # Final newline
    finally:
        socket.close()
def question(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, streaming=True, token=None,
explainable=False, debug=False
):
# Explainable mode uses direct websocket to capture provenance events
# Explainable mode uses the API to capture and process provenance events
if explainable:
asyncio.run(_question_explainable(
_question_explainable_api(
url=url,
flow_id=flow_id,
question=question,
question_text=question,
user=user,
collection=collection,
entity_limit=entity_limit,
@ -622,7 +724,7 @@ def question(
max_path_length=max_path_length,
token=token,
debug=debug
))
)
return
# Create API client

View file

@ -14,180 +14,17 @@ import json
import os
import sys
from tabulate import tabulate
from trustgraph.api import Api
from trustgraph.api import Api, ExplainabilityClient
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_user = 'trustgraph'
default_collection = 'default'
# Predicates
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_QUESTION = TG + "Question"
TG_ANALYSIS = TG + "Analysis"
TG_EXPLORATION = TG + "Exploration"
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
# Retrieval graph
RETRIEVAL_GRAPH = "urn:graph:retrieval"
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
    """Query triples using the socket API.

    Builds a non-streaming triples request and returns the results as a
    list of (subject, predicate, object) value tuples. Subject and
    predicate patterns are always sent as IRI terms; an object pattern is
    sent as an IRI when it looks like one (http/https/urn prefix),
    otherwise as a literal, and a dict object is passed through verbatim
    (e.g. a pre-built quoted-triple term). None means wildcard.

    Errors are reported on stderr and whatever was collected so far is
    returned - callers treat this as best-effort.

    NOTE(review): relies on the private socket._send_request_sync API;
    indentation reconstructed from a diff rendering - confirm nesting
    against the original file.
    """
    request = {
        "user": user,
        "collection": collection,
        "limit": limit,
        "streaming": False,
    }
    if s is not None:
        request["s"] = {"t": "i", "i": s}
    if p is not None:
        request["p"] = {"t": "i", "i": p}
    if o is not None:
        if isinstance(o, str):
            # Heuristic: URI-shaped strings become IRI terms, else literals
            if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"):
                request["o"] = {"t": "i", "i": o}
            else:
                request["o"] = {"t": "l", "v": o}
        elif isinstance(o, dict):
            request["o"] = o
    if g is not None:
        request["g"] = g
    triples = []
    try:
        for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
            # Response shape varies; triples may sit under "response" or "triples"
            if isinstance(response, dict):
                triple_list = response.get("response", response.get("triples", []))
            else:
                triple_list = response
            if not isinstance(triple_list, list):
                triple_list = [triple_list] if triple_list else []
            for t in triple_list:
                s_val = extract_value(t.get("s", {}))
                p_val = extract_value(t.get("p", {}))
                o_val = extract_value(t.get("o", {}))
                triples.append((s_val, p_val, o_val))
    except Exception as e:
        print(f"Error querying triples: {e}", file=sys.stderr)
    return triples
def extract_value(term):
    """Extract the plain value carried by a triple-term dict.

    IRI terms ("i") yield the IRI string, literal terms ("l") the literal
    value, and quoted-triple terms ("t") a dict of recursively extracted
    s/p/o values. Unrecognised dicts fall back to their raw "i"/"v"
    entries, then to str(). Empty/None input yields "".
    """
    if not term:
        return ""

    kind = term.get("t") or term.get("type")

    if kind == "i":
        return term.get("i") or term.get("iri", "")

    if kind == "l":
        return term.get("v") or term.get("value", "")

    if kind == "t":
        # Quoted triple: unpack each component recursively
        inner = term.get("tr") or term.get("triple", {})
        return {
            key: extract_value(inner.get(key, {}))
            for key in ("s", "p", "o")
        }

    # Fallback for raw values
    if "i" in term:
        return term["i"]
    if "v" in term:
        return term["v"]
    return str(term)
def get_timestamp(socket, flow_id, user, collection, question_id):
    """Return the prov:startedAtTime value for a question, or "" if absent."""
    matches = query_triples(
        socket, flow_id, user, collection,
        s=question_id, p=PROV_STARTED_AT_TIME, g=RETRIEVAL_GRAPH
    )
    if not matches:
        return ""
    # Only the object of the first matching triple is of interest
    _, _, started_at = matches[0]
    return started_at
def get_session_type(socket, flow_id, user, collection, session_id):
    """
    Get the type of session (Agent or GraphRAG).

    Both have tg:Question type, so we distinguish by URI pattern
    or by checking what's derived from it.

    Returns "Agent" or "GraphRAG"; unknown cases default to "GraphRAG".

    NOTE(review): indentation reconstructed from a diff rendering -
    confirm nesting against the original file.
    """
    # Fast path: check URI pattern
    if session_id.startswith("urn:trustgraph:agent:"):
        return "Agent"
    if session_id.startswith("urn:trustgraph:question:"):
        return "GraphRAG"
    # Check what's derived from this entity (both wasDerivedFrom and
    # wasGeneratedBy are considered, since the engines differ here)
    derived = query_triples(
        socket, flow_id, user, collection,
        p=PROV_WAS_DERIVED_FROM, o=session_id, g=RETRIEVAL_GRAPH
    )
    generated = query_triples(
        socket, flow_id, user, collection,
        p=PROV_WAS_GENERATED_BY, o=session_id, g=RETRIEVAL_GRAPH
    )
    for s, p, o in derived + generated:
        # Classify by the rdf:type of each child entity
        child_types = query_triples(
            socket, flow_id, user, collection,
            s=s, p=RDF_TYPE, g=RETRIEVAL_GRAPH
        )
        for _, _, child_type in child_types:
            if child_type == TG_ANALYSIS:
                return "Agent"
            if child_type == TG_EXPLORATION:
                return "GraphRAG"
    return "GraphRAG"
def list_sessions(socket, flow_id, user, collection, limit):
    """List all explainability sessions (GraphRAG and Agent) by finding questions.

    Every session is anchored by a tg:query triple in the retrieval graph;
    each hit is enriched with its timestamp and session type, and the
    result is returned newest-first.
    """
    question_triples = query_triples(
        socket, flow_id, user, collection,
        p=TG_QUERY, g=RETRIEVAL_GRAPH, limit=limit
    )

    def describe(question_id, query_text):
        # Build one session record; timestamp is looked up before type
        timestamp = get_timestamp(socket, flow_id, user, collection, question_id)
        session_type = get_session_type(socket, flow_id, user, collection, question_id)
        return {
            "id": question_id,
            "type": session_type,
            "question": query_text,
            "time": timestamp,
        }

    sessions = [describe(qid, text) for qid, _, text in question_triples]

    # Sort by timestamp (newest first); missing timestamps sort last
    sessions.sort(key=lambda entry: entry.get("time", ""), reverse=True)
    return sessions
def truncate_text(text, max_len=60):
"""Truncate text to max length with ellipsis."""
if not text:
@ -277,16 +114,42 @@ def main():
try:
api = Api(args.api_url, token=args.token)
socket = api.socket()
flow = socket.flow(args.flow_id)
explain_client = ExplainabilityClient(flow)
try:
sessions = list_sessions(
socket=socket,
flow_id=args.flow_id,
# List all sessions using the API
questions = explain_client.list_sessions(
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
limit=args.limit,
)
# Convert to output format
sessions = []
for q in questions:
session_type = explain_client.detect_session_type(
q.uri,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection
)
# Map type names
type_display = {
"graphrag": "GraphRAG",
"docrag": "DocRAG",
"agent": "Agent",
}.get(session_type, session_type.title())
sessions.append({
"id": q.uri,
"type": type_display,
"question": q.query,
"time": q.timestamp,
})
if args.format == 'json':
print_json(sessions)
else:

View file

@ -291,42 +291,25 @@ def query_graph(
):
"""Query the triple store with pattern matching.
Uses the WebSocket API's raw streaming mode for efficient delivery of results.
Uses the API's triples_query_stream for efficient streaming delivery.
"""
socket = Api(url, token=token).socket()
# Build request dict directly (bypassing triples_query_stream's string conversion)
request = {
"user": user,
"collection": collection,
"limit": limit,
"streaming": True,
"batch-size": batch_size,
}
# Add term dicts for s/p/o (None means wildcard)
if subject is not None:
request["s"] = subject
if predicate is not None:
request["p"] = predicate
if obj is not None:
request["o"] = obj
if graph is not None:
request["g"] = graph
flow = socket.flow(flow_id)
all_triples = []
try:
# Use raw streaming mode - yields response dicts directly
for response in socket._send_request_sync(
"triples", flow_id, request, streaming_raw=True
# Use triples_query_stream - accepts Term dicts directly
for triples in flow.triples_query_stream(
s=subject,
p=predicate,
o=obj,
g=graph,
user=user,
collection=collection,
limit=limit,
batch_size=batch_size,
):
# Response may have triples in different locations depending on format
if isinstance(response, dict):
triples = response.get("response", response.get("triples", []))
else:
triples = response
if not isinstance(triples, list):
triples = [triples] if triples else []

View file

@ -18,228 +18,99 @@ import argparse
import json
import os
import sys
from trustgraph.api import Api
from trustgraph.api import (
Api,
ExplainabilityClient,
Question,
Exploration,
Focus,
Synthesis,
Analysis,
Conclusion,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_user = 'trustgraph'
default_collection = 'default'
# Predicates
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_DOCUMENT = TG + "document"
TG_REIFIES = TG + "reifies"
# Explainability entity types
TG_QUESTION = TG + "Question"
TG_EXPLORATION = TG + "Exploration"
TG_FOCUS = TG + "Focus"
TG_SYNTHESIS = TG + "Synthesis"
TG_ANALYSIS = TG + "Analysis"
TG_CONCLUSION = TG + "Conclusion"
# Agent predicates
TG_THOUGHT = TG + "thought"
TG_ACTION = TG + "action"
TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation"
TG_ANSWER = TG + "answer"
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
# Graphs
RETRIEVAL_GRAPH = "urn:graph:retrieval"
SOURCE_GRAPH = "urn:graph:source"
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
"""Query triples using the socket API."""
request = {
"user": user,
"collection": collection,
"limit": limit,
"streaming": False,
}
if s is not None:
request["s"] = {"t": "i", "i": s}
if p is not None:
request["p"] = {"t": "i", "i": p}
if o is not None:
if isinstance(o, str):
if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"):
request["o"] = {"t": "i", "i": o}
else:
request["o"] = {"t": "l", "v": o}
elif isinstance(o, dict):
request["o"] = o
if g is not None:
request["g"] = g
triples = []
try:
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
if isinstance(response, dict):
triple_list = response.get("response", response.get("triples", []))
else:
triple_list = response
if not isinstance(triple_list, list):
triple_list = [triple_list] if triple_list else []
for t in triple_list:
s_val = extract_value(t.get("s", {}))
p_val = extract_value(t.get("p", {}))
o_val = extract_value(t.get("o", {}))
triples.append((s_val, p_val, o_val))
except Exception as e:
print(f"Error querying triples: {e}", file=sys.stderr)
return triples
# Provenance predicates for edge tracing
TG = "https://trustgraph.ai/ns/"
TG_REIFIES = TG + "reifies"
PROV = "http://www.w3.org/ns/prov#"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
def extract_value(term):
"""Extract value from a term dict."""
if not term:
return ""
def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_client):
"""
Trace an edge back to its source document via reification.
t = term.get("t") or term.get("type")
Args:
flow: SocketFlowInstance
user: User identifier
collection: Collection identifier
edge: Dict with s, p, o keys
label_cache: Dict for caching labels
explain_client: ExplainabilityClient for label resolution
if t == "i":
return term.get("i") or term.get("iri", "")
elif t == "l":
return term.get("v") or term.get("value", "")
elif t == "t":
# Quoted triple
tr = term.get("tr") or term.get("triple", {})
return {
"s": extract_value(tr.get("s", {})),
"p": extract_value(tr.get("p", {})),
"o": extract_value(tr.get("o", {})),
}
Returns:
List of provenance chains, each chain is list of {uri, label}
"""
edge_s = edge.get("s", "")
edge_p = edge.get("p", "")
edge_o = edge.get("o", "")
# Fallback for raw values
if "i" in term:
return term["i"]
if "v" in term:
return term["v"]
# Build quoted triple for lookup
def build_term(val):
if isinstance(val, str) and (val.startswith("http") or val.startswith("urn:")):
return {"t": "i", "i": val}
return {"t": "l", "v": str(val)}
return str(term)
def get_node_properties(socket, flow_id, user, collection, node_uri, graph=RETRIEVAL_GRAPH):
    """Collect all properties of a node as {predicate: [object, ...]}.

    Objects are grouped per predicate in the order the triples arrive.
    """
    props = {}
    for _, predicate, obj in query_triples(
            socket, flow_id, user, collection, s=node_uri, g=graph):
        props.setdefault(predicate, []).append(obj)
    return props
def find_by_predicate_object(socket, flow_id, user, collection, predicate, obj, graph=RETRIEVAL_GRAPH):
    """Return the subjects of all triples matching (?, predicate, obj) in graph."""
    matches = query_triples(
        socket, flow_id, user, collection, p=predicate, o=obj, g=graph
    )
    return [subject for subject, _, _ in matches]
def get_label(socket, flow_id, user, collection, uri, label_cache):
    """Resolve the rdfs:label for a URI, caching results in label_cache.

    Non-string and non-URI values are returned unchanged. A URI with no
    label is cached as its own label so it is only queried once.
    """
    is_uri = isinstance(uri, str) and uri.startswith(("http://", "https://", "urn:"))
    if not is_uri:
        return uri

    if uri not in label_cache:
        labels = query_triples(socket, flow_id, user, collection, s=uri, p=RDFS_LABEL)
        # First label wins; fall back to the URI itself when none exists
        label_cache[uri] = labels[0][2] if labels else uri

    return label_cache[uri]
def get_document_content(api, user, doc_id, max_content):
    """Fetch a document's content from the librarian API, best-effort.

    Returns the UTF-8 text (truncated to max_content characters with a
    "... [truncated]" suffix), a "[Binary: N bytes]" placeholder for
    undecodable content, or an error placeholder if anything fails.
    """
    try:
        raw = api.library().get_document_content(user=user, id=doc_id)
        # Prefer text; binary content is summarised rather than dumped
        try:
            text = raw.decode('utf-8')
        except UnicodeDecodeError:
            return f"[Binary: {len(raw)} bytes]"
        if len(text) > max_content:
            return text[:max_content] + "... [truncated]"
        return text
    except Exception as e:
        return f"[Error fetching content: {e}]"
def trace_edge_provenance(socket, flow_id, user, collection, edge_s, edge_p, edge_o, label_cache):
"""Trace an edge back to its source document via reification."""
# Build the quoted triple for lookup
quoted_triple = {
"t": "t",
"tr": {
"s": {"t": "i", "i": edge_s} if isinstance(edge_s, str) and (edge_s.startswith("http") or edge_s.startswith("urn:")) else {"t": "l", "v": edge_s},
"p": {"t": "i", "i": edge_p},
"o": {"t": "i", "i": edge_o} if isinstance(edge_o, str) and (edge_o.startswith("http") or edge_o.startswith("urn:")) else {"t": "l", "v": edge_o},
"s": build_term(edge_s),
"p": build_term(edge_p),
"o": build_term(edge_o),
}
}
# Query: ?stmt tg:reifies <<edge>>
request = {
"user": user,
"collection": collection,
"limit": 10,
"streaming": False,
"p": {"t": "i", "i": TG_REIFIES},
"o": quoted_triple,
"g": SOURCE_GRAPH,
}
stmt_uris = []
try:
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
if isinstance(response, dict):
triple_list = response.get("response", response.get("triples", []))
else:
triple_list = response
if not isinstance(triple_list, list):
triple_list = [triple_list] if triple_list else []
for t in triple_list:
s_val = extract_value(t.get("s", {}))
if s_val:
stmt_uris.append(s_val)
results = flow.triples_query(
p=TG_REIFIES,
o=quoted_triple,
g=SOURCE_GRAPH,
user=user,
collection=collection,
limit=10
)
except Exception:
pass
return []
# For each statement, find wasDerivedFrom chain
# Extract statement URIs
stmt_uris = []
for t in results:
s_term = t.get("s", {})
s_val = s_term.get("i") or s_term.get("v", "")
if s_val:
stmt_uris.append(s_val)
# For each statement, trace wasDerivedFrom chain
provenance_chains = []
for stmt_uri in stmt_uris:
chain = trace_provenance_chain(socket, flow_id, user, collection, stmt_uri, label_cache)
chain = trace_provenance_chain(flow, user, collection, stmt_uri, label_cache, explain_client)
if chain:
provenance_chains.append(chain)
return provenance_chains
def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_cache, max_depth=10):
def trace_provenance_chain(flow, user, collection, start_uri, label_cache, explain_client, max_depth=10):
"""Trace prov:wasDerivedFrom chain from start_uri to root."""
chain = []
current = start_uri
@ -248,17 +119,32 @@ def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_c
if not current:
break
label = get_label(socket, flow_id, user, collection, current, label_cache)
# Get label
if current in label_cache:
label = label_cache[current]
else:
label = explain_client.resolve_label(current, user, collection)
label_cache[current] = label
chain.append({"uri": current, "label": label})
# Get parent
triples = query_triples(
socket, flow_id, user, collection,
s=current, p=PROV_WAS_DERIVED_FROM, g=SOURCE_GRAPH
)
# Get parent via wasDerivedFrom
try:
results = flow.triples_query(
s=current,
p=PROV_WAS_DERIVED_FROM,
g=SOURCE_GRAPH,
user=user,
collection=collection,
limit=1
)
except Exception:
break
parent = None
for s, p, o in triples:
parent = o
for t in results:
o_term = t.get("o", {})
parent = o_term.get("i") or o_term.get("v", "")
break
if not parent or parent == current:
@ -276,331 +162,24 @@ def format_provenance_chain(chain):
return " -> ".join(labels)
def format_edge(edge, label_cache=None, socket=None, flow_id=None, user=None, collection=None):
"""Format a quoted triple edge for display."""
if not isinstance(edge, dict):
return str(edge)
def print_graphrag_text(trace, explain_client, flow, user, collection, show_provenance=False):
"""Print GraphRAG trace in text format."""
question = trace.get("question")
s = edge.get("s", "?")
p = edge.get("p", "?")
o = edge.get("o", "?")
# Get labels if available
if label_cache and socket:
s_label = get_label(socket, flow_id, user, collection, s, label_cache)
p_label = get_label(socket, flow_id, user, collection, p, label_cache)
o_label = get_label(socket, flow_id, user, collection, o, label_cache)
else:
# Shorten URIs for display
s_label = s.split("/")[-1] if "/" in str(s) else s
p_label = p.split("/")[-1] if "/" in str(p) else p
o_label = o.split("/")[-1] if "/" in str(o) else o
return f"({s_label}, {p_label}, {o_label})"
def detect_trace_type(socket, flow_id, user, collection, entity_id):
    """Classify an explain-trace root entity as "agent" or "graphrag".

    Both trace kinds use rdf:type = tg:Question for their root node, so
    the URI prefix is tried first as a fast path.  Failing that, the
    entities linked to this one via prov:wasDerivedFrom and
    prov:wasGeneratedBy are inspected: a tg:Analysis or tg:Conclusion
    child indicates an agent session, a tg:Exploration child indicates
    GraphRAG.  When nothing conclusive is found, "graphrag" is assumed.

    Returns:
        "agent" or "graphrag"
    """
    # Fast path: well-known URI prefixes identify the trace kind directly.
    prefix_kinds = (
        ("urn:trustgraph:agent:", "agent"),
        ("urn:trustgraph:question:", "graphrag"),
    )
    for prefix, kind in prefix_kinds:
        if entity_id.startswith(prefix):
            return kind

    # Slow path: gather children reached by either provenance predicate
    # (wasDerivedFrom first, matching the original lookup order).
    children = []
    for predicate in (PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY):
        children.extend(find_by_predicate_object(
            socket, flow_id, user, collection,
            predicate, entity_id
        ))

    # Inspect each child's rdf:type in the retrieval graph.
    for child_id in children:
        type_triples = query_triples(
            socket, flow_id, user, collection,
            s=child_id, p=RDF_TYPE, g=RETRIEVAL_GRAPH
        )
        for _subj, _pred, obj in type_triples:
            if obj in (TG_ANALYSIS, TG_CONCLUSION):
                return "agent"
            if obj == TG_EXPLORATION:
                return "graphrag"

    # Default to graphrag when no signal was found.
    return "graphrag"
def build_agent_trace(socket, flow_id, user, collection, session_id, api=None, max_answer=500):
    """Build the full explainability trace for an agent session.

    Walks the prov:wasDerivedFrom chain starting at the session node.
    Each tg:Analysis hop becomes one think/act/observe iteration; a
    tg:Conclusion hop terminates the walk and supplies the final answer.
    Traversal also stops at the first node with no derived children or
    with an unrecognised type.

    Args:
        socket: API socket passed through to the triple-query helpers.
        flow_id: Flow identifier for the queries.
        user: User scope for the queries.
        collection: Collection scope for the queries.
        session_id: URI of the agent session entity to trace.
        api: Accepted for signature parity with build_trace; not used
            by this function's visible logic.
        max_answer: Truncate the final answer beyond this many characters.

    Returns:
        dict with keys: session_id, type ("agent"), question, time,
        iterations (list of per-iteration dicts), final_answer
        (dict with id/answer, or None if no conclusion was found).
    """
    trace = {
        "session_id": session_id,
        "type": "agent",
        "question": None,
        "time": None,
        "iterations": [],
        "final_answer": None,
    }

    # Get session metadata (question text and start time) from the
    # session node's properties.
    props = get_node_properties(socket, flow_id, user, collection, session_id)
    trace["question"] = props.get(TG_QUERY, [None])[0]
    trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0]

    # Find all entities derived from this session (iterations and final)
    # Start by looking for entities where prov:wasDerivedFrom = session_id
    current_uri = session_id
    iteration_num = 1

    while True:
        # Find entities derived from current
        derived_ids = find_by_predicate_object(
            socket, flow_id, user, collection,
            PROV_WAS_DERIVED_FROM, current_uri
        )
        if not derived_ids:
            break

        # Only the first derived entity is followed: the chain is
        # expected to be linear (one child per node).
        derived_id = derived_ids[0]
        derived_props = get_node_properties(socket, flow_id, user, collection, derived_id)

        # Check type
        types = derived_props.get(RDF_TYPE, [])

        if TG_ANALYSIS in types:
            # One think/act/observe iteration of the agent loop; each
            # property is single-valued, so take the first entry.
            iteration = {
                "id": derived_id,
                "iteration_num": iteration_num,
                "thought": derived_props.get(TG_THOUGHT, [None])[0],
                "action": derived_props.get(TG_ACTION, [None])[0],
                "arguments": derived_props.get(TG_ARGUMENTS, [None])[0],
                "observation": derived_props.get(TG_OBSERVATION, [None])[0],
            }
            trace["iterations"].append(iteration)
            current_uri = derived_id
            iteration_num += 1
        elif TG_CONCLUSION in types:
            # Terminal node: capture the (possibly truncated) answer.
            answer = derived_props.get(TG_ANSWER, [None])[0]
            if answer and len(answer) > max_answer:
                answer = answer[:max_answer] + "... [truncated]"
            trace["final_answer"] = {
                "id": derived_id,
                "answer": answer,
            }
            break
        else:
            # Unknown type, stop traversal
            break

    return trace
def print_agent_text(trace):
"""Print agent trace in text format."""
print(f"=== Agent Session: {trace['session_id']} ===")
print(f"=== GraphRAG Session: {question.uri if question else 'Unknown'} ===")
print()
if trace["question"]:
print(f"Question: {trace['question']}")
if trace["time"]:
print(f"Time: {trace['time']}")
print()
# Analysis steps
print("--- Analysis ---")
iterations = trace.get("iterations", [])
if iterations:
for iteration in iterations:
print(f"Analysis {iteration['iteration_num']}:")
print(f" Thought: {iteration.get('thought', 'N/A')}")
print(f" Action: {iteration.get('action', 'N/A')}")
args = iteration.get('arguments')
if args:
# Try to pretty-print JSON arguments
try:
import json
args_obj = json.loads(args)
args_str = json.dumps(args_obj, indent=4)
# Indent each line
args_lines = args_str.split('\n')
print(f" Arguments:")
for line in args_lines:
print(f" {line}")
except:
print(f" Arguments: {args}")
else:
print(f" Arguments: N/A")
obs = iteration.get('observation', 'N/A')
if obs and len(obs) > 200:
obs = obs[:200] + "... [truncated]"
print(f" Observation: {obs}")
print()
else:
print("No analysis steps recorded")
print()
# Conclusion
print("--- Conclusion ---")
final = trace.get("final_answer")
if final and final.get("answer"):
print("Answer:")
for line in final["answer"].split("\n"):
print(f" {line}")
else:
print("No conclusion recorded")
def print_agent_json(trace):
    """Serialize an agent trace to indented JSON and write it to stdout."""
    rendered = json.dumps(trace, indent=2)
    print(rendered)
def build_trace(socket, flow_id, user, collection, question_id, api=None, show_provenance=False, max_answer=500):
    """Build the full explainability trace for a GraphRAG question.

    Follows the provenance chain recorded in the knowledge graph:
    question -> exploration (wasGeneratedBy) -> focus (wasDerivedFrom)
    -> synthesis (wasDerivedFrom).  Each stage is only queried if its
    predecessor was found, so a partial trace yields None entries.

    Args:
        socket: API socket passed through to the triple-query helpers.
        flow_id: Flow identifier for the queries.
        user: User scope for the queries.
        collection: Collection scope for the queries.
        question_id: URI of the question entity to trace.
        api: Optional Api instance, used to fetch synthesis content
            from librarian when only a document reference is stored.
        show_provenance: When True, also trace each selected edge's
            derivation chain and attach it to the edge info.
        max_answer: Truncate the synthesis answer beyond this many
            characters.

    Returns:
        dict with keys: question_id, question, time, exploration,
        focus, synthesis, plus an internal "_label_cache" entry used
        by the text formatter (stripped before JSON output).
    """
    # Shared cache so each URI label is resolved at most once.
    label_cache = {}

    trace = {
        "question_id": question_id,
        "question": None,
        "time": None,
        "exploration": None,
        "focus": None,
        "synthesis": None,
    }

    # Get question metadata
    props = get_node_properties(socket, flow_id, user, collection, question_id)
    trace["question"] = props.get(TG_QUERY, [None])[0]
    trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0]

    # Find exploration: ?exploration prov:wasGeneratedBy question_id
    exploration_ids = find_by_predicate_object(
        socket, flow_id, user, collection,
        PROV_WAS_GENERATED_BY, question_id
    )
    if exploration_ids:
        exploration_id = exploration_ids[0]
        exploration_props = get_node_properties(socket, flow_id, user, collection, exploration_id)
        trace["exploration"] = {
            "id": exploration_id,
            "edge_count": exploration_props.get(TG_EDGE_COUNT, [None])[0],
        }

        # Find focus: ?focus prov:wasDerivedFrom exploration_id
        focus_ids = find_by_predicate_object(
            socket, flow_id, user, collection,
            PROV_WAS_DERIVED_FROM, exploration_id
        )
        if focus_ids:
            focus_id = focus_ids[0]
            focus_props = get_node_properties(socket, flow_id, user, collection, focus_id)

            # Get selected edges: each selection node holds the quoted
            # edge plus the LLM's reasoning for choosing it.
            edge_selection_uris = focus_props.get(TG_SELECTED_EDGE, [])
            selected_edges = []
            for edge_sel_uri in edge_selection_uris:
                edge_sel_props = get_node_properties(socket, flow_id, user, collection, edge_sel_uri)
                edge = edge_sel_props.get(TG_EDGE, [None])[0]
                reasoning = edge_sel_props.get(TG_REASONING, [None])[0]
                edge_info = {
                    "edge": edge,
                    "reasoning": reasoning,
                }
                # Trace provenance if requested
                if show_provenance and isinstance(edge, dict):
                    provenance = trace_edge_provenance(
                        socket, flow_id, user, collection,
                        edge.get("s", ""), edge.get("p", ""), edge.get("o", ""),
                        label_cache
                    )
                    edge_info["provenance"] = provenance
                selected_edges.append(edge_info)

            trace["focus"] = {
                "id": focus_id,
                "selected_edges": selected_edges,
            }

            # Find synthesis: ?synthesis prov:wasDerivedFrom focus_id
            synthesis_ids = find_by_predicate_object(
                socket, flow_id, user, collection,
                PROV_WAS_DERIVED_FROM, focus_id
            )
            if synthesis_ids:
                synthesis_id = synthesis_ids[0]
                synthesis_props = get_node_properties(socket, flow_id, user, collection, synthesis_id)

                # Get content directly or via document reference: large
                # answers live in librarian, referenced by document id.
                content = synthesis_props.get(TG_CONTENT, [None])[0]
                doc_id = synthesis_props.get(TG_DOCUMENT, [None])[0]
                if not content and doc_id and api:
                    content = get_document_content(api, user, doc_id, max_answer)
                elif content and len(content) > max_answer:
                    content = content[:max_answer] + "... [truncated]"

                trace["synthesis"] = {
                    "id": synthesis_id,
                    "document_id": doc_id,
                    "answer": content,
                }

    # Store label cache for formatting
    trace["_label_cache"] = label_cache

    return trace
def print_text(trace, show_provenance=False):
"""Print trace in text format."""
label_cache = trace.get("_label_cache", {})
print(f"=== GraphRAG Session: {trace['question_id']} ===")
print()
if trace["question"]:
print(f"Question: {trace['question']}")
if trace["time"]:
print(f"Time: {trace['time']}")
if question:
print(f"Question: {question.query}")
if question.timestamp:
print(f"Time: {question.timestamp}")
print()
# Exploration
print("--- Exploration ---")
exploration = trace.get("exploration")
if exploration:
edge_count = exploration.get("edge_count", "?")
print(f"Retrieved {edge_count} edges from knowledge graph")
print(f"Retrieved {exploration.edge_count} edges from knowledge graph")
else:
print("No exploration data found")
print()
@ -609,24 +188,28 @@ def print_text(trace, show_provenance=False):
print("--- Focus (Edge Selection) ---")
focus = trace.get("focus")
if focus:
edges = focus.get("selected_edges", [])
edges = focus.edge_selections
print(f"Selected {len(edges)} edges:")
print()
for i, edge_info in enumerate(edges, 1):
edge = edge_info.get("edge")
reasoning = edge_info.get("reasoning")
label_cache = {}
if edge:
edge_str = format_edge(edge)
print(f" {i}. {edge_str}")
for i, edge_sel in enumerate(edges, 1):
if edge_sel.edge:
s_label, p_label, o_label = explain_client.resolve_edge_labels(
edge_sel.edge, user, collection
)
print(f" {i}. ({s_label}, {p_label}, {o_label})")
if reasoning:
r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
if edge_sel.reasoning:
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reasoning: {r_short}")
if show_provenance:
provenance = edge_info.get("provenance", [])
if show_provenance and edge_sel.edge:
provenance = trace_edge_provenance(
flow, user, collection, edge_sel.edge,
label_cache, explain_client
)
for chain in provenance:
chain_str = format_provenance_chain(chain)
if chain_str:
@ -641,11 +224,9 @@ def print_text(trace, show_provenance=False):
print("--- Synthesis ---")
synthesis = trace.get("synthesis")
if synthesis:
answer = synthesis.get("answer")
if answer:
if synthesis.content:
print("Answer:")
# Indent the answer
for line in answer.split("\n"):
for line in synthesis.content.split("\n"):
print(f" {line}")
else:
print("No answer content found")
@ -653,11 +234,173 @@ def print_text(trace, show_provenance=False):
print("No synthesis data found")
def print_json(trace):
    """Emit the trace as indented JSON, omitting internal "_"-prefixed keys."""
    public_items = {}
    for key, value in trace.items():
        # Keys beginning with "_" (e.g. the label cache) are internal state.
        if key.startswith("_"):
            continue
        public_items[key] = value
    print(json.dumps(public_items, indent=2))
def print_docrag_text(trace):
    """Render a DocumentRAG explainability trace as human-readable text."""
    question = trace.get("question")

    session = question.uri if question else "Unknown"
    print(f"=== DocRAG Session: {session} ===")
    print()

    if question:
        print(f"Question: {question.query}")
        if question.timestamp:
            print(f"Time: {question.timestamp}")
    print()

    # Exploration: how many chunks were pulled from the document store.
    print("--- Exploration ---")
    exploration = trace.get("exploration")
    if not exploration:
        print("No exploration data found")
    else:
        print(f"Retrieved {exploration.chunk_count} chunks from document store")
    print()

    # Synthesis: the generated answer (DocRAG has no Focus step).
    print("--- Synthesis ---")
    synthesis = trace.get("synthesis")
    if not synthesis:
        print("No synthesis data found")
    elif not synthesis.content:
        print("No answer content found")
    else:
        print("Answer:")
        for line in synthesis.content.split("\n"):
            print(f"    {line}")
def print_agent_text(trace):
    """Render an Agent explainability trace as human-readable text."""
    question = trace.get("question")

    session = question.uri if question else "Unknown"
    print(f"=== Agent Session: {session} ===")
    print()

    if question:
        print(f"Question: {question.query}")
        if question.timestamp:
            print(f"Time: {question.timestamp}")
    print()

    # Think/act/observe iterations recorded for the session.
    print("--- Analysis ---")
    iterations = trace.get("iterations", [])
    if not iterations:
        print("No analysis steps recorded")
    else:
        for i, analysis in enumerate(iterations, 1):
            print(f"Analysis {i}:")
            print(f"  Thought: {analysis.thought or 'N/A'}")
            print(f"  Action: {analysis.action or 'N/A'}")
            if not analysis.arguments:
                print(f"  Arguments: N/A")
            else:
                # Pretty-print JSON arguments where possible; otherwise
                # fall back to printing the raw string.
                try:
                    parsed = json.loads(analysis.arguments)
                except Exception:
                    print(f"  Arguments: {analysis.arguments}")
                else:
                    print(f"  Arguments:")
                    for line in json.dumps(parsed, indent=4).split('\n'):
                        print(f"    {line}")
            obs = analysis.observation or 'N/A'
            if obs and len(obs) > 200:
                obs = obs[:200] + "... [truncated]"
            print(f"  Observation: {obs}")
            print()
    print()

    # Final answer, if the session reached a conclusion.
    print("--- Conclusion ---")
    conclusion = trace.get("conclusion")
    if conclusion and conclusion.answer:
        print("Answer:")
        for line in conclusion.answer.split("\n"):
            print(f"    {line}")
    else:
        print("No conclusion recorded")
def trace_to_dict(trace, trace_type):
    """Convert trace entities to a JSON-serializable dict.

    The shape of the returned dict depends on trace_type: "agent"
    (session/iterations/conclusion), "docrag" (question/exploration/
    synthesis), or anything else, treated as "graphrag" (adds the
    focus edge-selection step).
    """
    question = trace.get("question")

    # Question fields are common to every trace type.
    q_uri = question.uri if question else None
    q_query = question.query if question else None
    q_time = question.timestamp if question else None

    if trace_type == "agent":
        iterations = []
        for analysis in trace.get("iterations", []):
            iterations.append({
                "id": analysis.uri,
                "thought": analysis.thought,
                "action": analysis.action,
                "arguments": analysis.arguments,
                "observation": analysis.observation,
            })
        conclusion = trace.get("conclusion")
        conclusion_dict = None
        if conclusion:
            conclusion_dict = {
                "id": conclusion.uri,
                "answer": conclusion.answer,
            }
        return {
            "type": "agent",
            "session_id": q_uri,
            "question": q_query,
            "time": q_time,
            "iterations": iterations,
            "conclusion": conclusion_dict,
        }

    # Both retrieval types share the exploration/synthesis sections.
    exploration = trace.get("exploration")
    synthesis = trace.get("synthesis")
    synthesis_dict = None
    if synthesis:
        synthesis_dict = {
            "id": synthesis.uri,
            "document_uri": synthesis.document_uri,
            "answer": synthesis.content,
        }

    if trace_type == "docrag":
        exploration_dict = None
        if exploration:
            exploration_dict = {
                "id": exploration.uri,
                "chunk_count": exploration.chunk_count,
            }
        return {
            "type": "docrag",
            "question_id": q_uri,
            "question": q_query,
            "time": q_time,
            "exploration": exploration_dict,
            "synthesis": synthesis_dict,
        }

    # Default: GraphRAG, which additionally records the Focus step.
    exploration_dict = None
    if exploration:
        exploration_dict = {
            "id": exploration.uri,
            "edge_count": exploration.edge_count,
        }
    focus = trace.get("focus")
    focus_dict = None
    if focus:
        focus_dict = {
            "id": focus.uri,
            "selected_edges": [
                {"edge": sel.edge, "reasoning": sel.reasoning}
                for sel in focus.edge_selections
            ],
        }
    return {
        "type": "graphrag",
        "question_id": q_uri,
        "question": q_query,
        "time": q_time,
        "exploration": exploration_dict,
        "focus": focus_dict,
        "synthesis": synthesis_dict,
    }
def main():
@ -727,50 +470,69 @@ def main():
try:
api = Api(args.api_url, token=args.token)
socket = api.socket()
flow = socket.flow(args.flow_id)
explain_client = ExplainabilityClient(flow)
try:
# Detect trace type (agent vs graphrag)
trace_type = detect_trace_type(
socket=socket,
flow_id=args.flow_id,
# Detect trace type
trace_type = explain_client.detect_session_type(
args.question_id,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
entity_id=args.question_id,
)
if trace_type == "agent":
# Build and print agent trace
trace = build_agent_trace(
socket=socket,
flow_id=args.flow_id,
# Fetch and display agent trace
trace = explain_client.fetch_agent_trace(
args.question_id,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
session_id=args.question_id,
api=api,
max_answer=args.max_answer,
max_content=args.max_answer,
)
if args.format == 'json':
print_agent_json(trace)
print(json.dumps(trace_to_dict(trace, "agent"), indent=2))
else:
print_agent_text(trace)
else:
# Build and print GraphRAG trace (existing behavior)
trace = build_trace(
socket=socket,
flow_id=args.flow_id,
elif trace_type == "docrag":
# Fetch and display DocRAG trace
trace = explain_client.fetch_docrag_trace(
args.question_id,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
question_id=args.question_id,
api=api,
show_provenance=args.show_provenance,
max_answer=args.max_answer,
max_content=args.max_answer,
)
if args.format == 'json':
print_json(trace)
print(json.dumps(trace_to_dict(trace, "docrag"), indent=2))
else:
print_text(trace, show_provenance=args.show_provenance)
print_docrag_text(trace)
else:
# Fetch and display GraphRAG trace
trace = explain_client.fetch_graphrag_trace(
args.question_id,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
api=api,
max_content=args.max_answer,
)
if args.format == 'json':
print(json.dumps(trace_to_dict(trace, "graphrag"), indent=2))
else:
print_graphrag_text(
trace, explain_client, flow,
args.user, args.collection,
show_provenance=args.show_provenance
)
finally:
socket.close()