mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
Add explainability CLI tools (#688)
Add explainability CLI tools for debugging provenance data - tg-show-document-hierarchy: Display document → page → chunk → edge hierarchy by traversing prov:wasDerivedFrom relationships - tg-list-explain-traces: List all GraphRAG sessions with questions and timestamps from the retrieval graph - tg-show-explain-trace: Show full explainability cascade for a GraphRAG session (question → exploration → focus → synthesis) These tools query the source and retrieval graphs to help debug and explore provenance/explainability data stored during document processing and GraphRAG queries.
This commit is contained in:
parent
fda508fdae
commit
a53ed41da2
5 changed files with 1469 additions and 0 deletions
257
trustgraph-cli/trustgraph/cli/list_explain_traces.py
Normal file
257
trustgraph-cli/trustgraph/cli/list_explain_traces.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
"""
|
||||
List all GraphRAG sessions (questions) in a collection.
|
||||
|
||||
Queries for all questions stored in the retrieval graph and displays them
|
||||
with their session IDs and timestamps.
|
||||
|
||||
Examples:
|
||||
tg-list-explain-traces -U trustgraph -C default
|
||||
tg-list-explain-traces --limit 20 --format json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from tabulate import tabulate
|
||||
from trustgraph.api import Api
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
# Predicates
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
TG_QUERY = TG + "query"
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
|
||||
|
||||
# Retrieval graph
|
||||
RETRIEVAL_GRAPH = "urn:graph:retrieval"
|
||||
|
||||
|
||||
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
    """Query triples through the socket API and return them as plain tuples.

    Builds a non-streaming ``triples`` request for ``flow_id`` and sends it
    with ``socket._send_request_sync``. ``s`` and ``p`` are sent as IRI
    terms. A string ``o`` is sent as an IRI when it starts with ``http://``,
    ``https://`` or ``urn:`` and as a literal otherwise; a dict ``o`` is sent
    unchanged as a ready-made term. ``g`` is passed through as given.

    Returns a list of ``(s, p, o)`` tuples decoded with ``extract_value``.
    Any exception raised while querying or decoding is reported on stderr
    and whatever was collected up to that point is returned.
    """
    request = {
        "user": user,
        "collection": collection,
        "limit": limit,
        "streaming": False,
    }

    if s is not None:
        request["s"] = {"t": "i", "i": s}
    if p is not None:
        request["p"] = {"t": "i", "i": p}
    if isinstance(o, str):
        if o.startswith(("http://", "https://", "urn:")):
            request["o"] = {"t": "i", "i": o}
        else:
            request["o"] = {"t": "l", "v": o}
    elif isinstance(o, dict):
        request["o"] = o
    if g is not None:
        request["g"] = g

    triples = []
    try:
        for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
            # A response is either a dict wrapping the triples or the triples themselves.
            if isinstance(response, dict):
                batch = response.get("response", response.get("triples", []))
            else:
                batch = response

            # Normalise a single item (or an empty value) to a list.
            if not isinstance(batch, list):
                batch = [batch] if batch else []

            for item in batch:
                subject = extract_value(item.get("s", {}))
                predicate = extract_value(item.get("p", {}))
                obj = extract_value(item.get("o", {}))
                triples.append((subject, predicate, obj))
    except Exception as e:
        print(f"Error querying triples: {e}", file=sys.stderr)

    return triples
|
||||
|
||||
|
||||
def extract_value(term):
    """Extract a plain value from a term dict.

    Handles both short (``t``/``i``/``v``/``tr``) and long
    (``type``/``iri``/``value``/``triple``) key spellings.

    Returns:
        ``""`` for an empty/None term; the IRI or literal value for ``i`` and
        ``l`` terms; a ``{"s", "p", "o"}`` dict of extracted values for a
        quoted triple (``t``); otherwise the raw ``i``/``v`` entry if present,
        else ``str(term)``.
    """
    if not term:
        return ""

    # A term that is already a bare value (e.g. a plain string) has nothing
    # to unwrap; calling .get on it would raise AttributeError.
    if not isinstance(term, dict):
        return str(term)

    kind = term.get("t") or term.get("type")

    if kind == "i":
        return term.get("i") or term.get("iri", "")
    if kind == "l":
        return term.get("v") or term.get("value", "")
    if kind == "t":
        # Quoted triple: tolerate a missing, None or malformed payload.
        inner = term.get("tr") or term.get("triple") or {}
        if not isinstance(inner, dict):
            inner = {}
        return {
            "s": extract_value(inner.get("s", {})),
            "p": extract_value(inner.get("p", {})),
            "o": extract_value(inner.get("o", {})),
        }

    # Fallback for raw values without a recognised type tag
    if "i" in term:
        return term["i"]
    if "v" in term:
        return term["v"]

    return str(term)
|
||||
|
||||
|
||||
def get_timestamp(socket, flow_id, user, collection, question_id):
    """Return the prov:startedAtTime value of a question, or "" if none.

    Looks the value up in the retrieval graph and uses the first triple
    returned by the query.
    """
    matches = query_triples(
        socket, flow_id, user, collection,
        s=question_id, p=PROV_STARTED_AT_TIME, g=RETRIEVAL_GRAPH
    )
    if not matches:
        return ""
    _, _, started_at = matches[0]
    return started_at
|
||||
|
||||
|
||||
def list_sessions(socket, flow_id, user, collection, limit):
    """List GraphRAG sessions by finding question nodes.

    Queries the retrieval graph for up to ``limit`` triples whose predicate
    is tg:query, then issues one extra query per question to fetch its
    timestamp.

    Returns a list of ``{"id", "question", "time"}`` dicts sorted by time,
    newest first.
    """
    question_triples = query_triples(
        socket, flow_id, user, collection,
        p=TG_QUERY, g=RETRIEVAL_GRAPH, limit=limit
    )

    sessions = [
        {
            "id": question_id,
            "question": query_text,
            # Timestamp is looked up separately; "" when not recorded.
            "time": get_timestamp(socket, flow_id, user, collection, question_id),
        }
        for question_id, _, query_text in question_triples
    ]

    def started(entry):
        return entry.get("time", "")

    # Newest first; entries without a timestamp sort last.
    sessions.sort(key=started, reverse=True)
    return sessions
|
||||
|
||||
|
||||
def truncate_text(text, max_len=60):
    """Truncate text to at most ``max_len`` characters, ending in an ellipsis.

    Args:
        text: Value to shorten. Falsy values give ``""``; non-string values
            are converted with ``str`` first.
        max_len: Maximum length of the returned string.

    Returns:
        ``text`` unchanged when it fits, otherwise a prefix followed by
        ``"..."``. When ``max_len`` is below 3 there is no room for the
        ellipsis, so the text is simply cut to ``max_len`` characters.
    """
    if not text:
        return ""

    text = str(text)
    if len(text) <= max_len:
        return text

    # Without this guard, max_len - 3 goes negative and the slice counts from
    # the end, producing a result longer than max_len.
    if max_len < 3:
        return text[:max(max_len, 0)]

    return text[:max_len - 3] + "..."
|
||||
|
||||
|
||||
def print_table(sessions):
    """Print sessions as a table with ID, question and time columns.

    Prints a "no sessions" message instead when ``sessions`` is empty.
    Questions are shortened to 50 characters for display.
    """
    if not sessions:
        print("No GraphRAG sessions found.")
        return

    rows = [
        [entry["id"], truncate_text(entry["question"], 50), entry.get("time", "")]
        for entry in sessions
    ]
    print(tabulate(rows, headers=["Session ID", "Question", "Time"], tablefmt="simple"))
|
||||
|
||||
|
||||
def print_json(sessions):
    """Print sessions to stdout as JSON indented by two spaces."""
    document = json.dumps(sessions, indent=2)
    print(document)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for tg-list-explain-traces.

    Parses the options, opens a socket via ``Api(...).socket()``, lists the
    sessions and prints them as a table or as JSON. The socket is always
    closed. Any exception is reported on stderr and the process exits with
    status 1.
    """
    parser = argparse.ArgumentParser(
        prog='tg-list-explain-traces',
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Declared in display order; each entry is (flags, add_argument kwargs).
    option_specs = [
        (('-u', '--api-url'), {
            "default": default_url,
            "help": f'API URL (default: {default_url})',
        }),
        (('-t', '--token'), {
            "default": default_token,
            "help": 'Auth token (default: $TRUSTGRAPH_TOKEN)',
        }),
        (('-U', '--user'), {
            "default": default_user,
            "help": f'User ID (default: {default_user})',
        }),
        (('-C', '--collection'), {
            "default": default_collection,
            "help": f'Collection (default: {default_collection})',
        }),
        (('-f', '--flow-id'), {
            "default": 'default',
            "help": 'Flow ID (default: default)',
        }),
        (('--limit',), {
            "type": int,
            "default": 50,
            "help": 'Max results (default: 50)',
        }),
        (('--format',), {
            "choices": ['table', 'json'],
            "default": 'table',
            "help": 'Output format: table (default), json',
        }),
    ]
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    args = parser.parse_args()

    try:
        api = Api(args.api_url, token=args.token)
        socket = api.socket()

        try:
            sessions = list_sessions(
                socket=socket,
                flow_id=args.flow_id,
                user=args.user,
                collection=args.collection,
                limit=args.limit,
            )

            render = print_json if args.format == 'json' else print_table
            render(sessions)

        finally:
            # Release the connection whether or not listing succeeded.
            socket.close()

    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
|
||||
431
trustgraph-cli/trustgraph/cli/show_document_hierarchy.py
Normal file
431
trustgraph-cli/trustgraph/cli/show_document_hierarchy.py
Normal file
|
|
@ -0,0 +1,431 @@
|
|||
"""
|
||||
Show document hierarchy: Document -> Pages -> Chunks -> Edges.
|
||||
|
||||
Given a document ID, traverses and displays all derived entities
|
||||
(pages, chunks, extracted edges) using prov:wasDerivedFrom relationships.
|
||||
|
||||
Examples:
|
||||
tg-show-document-hierarchy -U trustgraph -C default "urn:trustgraph:doc:abc123"
|
||||
tg-show-document-hierarchy --show-content --max-content 500 "urn:trustgraph:doc:abc123"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from trustgraph.api import Api
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
# Predicates
|
||||
PROV_WAS_DERIVED_FROM = "http://www.w3.org/ns/prov#wasDerivedFrom"
|
||||
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
TG_REIFIES = TG + "reifies"
|
||||
DC_TITLE = "http://purl.org/dc/terms/title"
|
||||
DC_FORMAT = "http://purl.org/dc/terms/format"
|
||||
|
||||
# Source graph
|
||||
SOURCE_GRAPH = "urn:graph:source"
|
||||
|
||||
|
||||
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
    """Query triples through the socket API and return them as plain tuples.

    Builds a non-streaming ``triples`` request for ``flow_id`` and sends it
    with ``socket._send_request_sync``. ``s`` and ``p`` are sent as IRI
    terms. A string ``o`` is sent as an IRI when it starts with ``http://``,
    ``https://`` or ``urn:`` and as a literal otherwise; a dict ``o`` is sent
    unchanged as a ready-made term. ``g`` is passed through as given.

    Returns a list of ``(s, p, o)`` tuples decoded with ``extract_value``.
    Any exception raised while querying or decoding is reported on stderr
    and whatever was collected up to that point is returned.
    """
    request = {
        "user": user,
        "collection": collection,
        "limit": limit,
        "streaming": False,
    }

    if s is not None:
        request["s"] = {"t": "i", "i": s}
    if p is not None:
        request["p"] = {"t": "i", "i": p}
    if isinstance(o, str):
        if o.startswith(("http://", "https://", "urn:")):
            request["o"] = {"t": "i", "i": o}
        else:
            request["o"] = {"t": "l", "v": o}
    elif isinstance(o, dict):
        request["o"] = o
    if g is not None:
        request["g"] = g

    triples = []
    try:
        for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
            # A response is either a dict wrapping the triples or the triples themselves.
            if isinstance(response, dict):
                batch = response.get("response", response.get("triples", []))
            else:
                batch = response

            # Normalise a single item (or an empty value) to a list.
            if not isinstance(batch, list):
                batch = [batch] if batch else []

            for item in batch:
                subject = extract_value(item.get("s", {}))
                predicate = extract_value(item.get("p", {}))
                obj = extract_value(item.get("o", {}))
                triples.append((subject, predicate, obj))
    except Exception as e:
        print(f"Error querying triples: {e}", file=sys.stderr)

    return triples
|
||||
|
||||
|
||||
def extract_value(term):
    """Extract a plain value from a term dict.

    Handles both short (``t``/``i``/``v``/``tr``) and long
    (``type``/``iri``/``value``/``triple``) key spellings.

    Returns:
        ``""`` for an empty/None term; the IRI or literal value for ``i`` and
        ``l`` terms; a ``{"s", "p", "o"}`` dict of extracted values for a
        quoted triple (``t``); otherwise the raw ``i``/``v`` entry if present,
        else ``str(term)``.
    """
    if not term:
        return ""

    # A term that is already a bare value (e.g. a plain string) has nothing
    # to unwrap; calling .get on it would raise AttributeError.
    if not isinstance(term, dict):
        return str(term)

    kind = term.get("t") or term.get("type")

    if kind == "i":
        return term.get("i") or term.get("iri", "")
    if kind == "l":
        return term.get("v") or term.get("value", "")
    if kind == "t":
        # Quoted triple: tolerate a missing, None or malformed payload.
        inner = term.get("tr") or term.get("triple") or {}
        if not isinstance(inner, dict):
            inner = {}
        return {
            "s": extract_value(inner.get("s", {})),
            "p": extract_value(inner.get("p", {})),
            "o": extract_value(inner.get("o", {})),
        }

    # Fallback for raw values without a recognised type tag
    if "i" in term:
        return term["i"]
    if "v" in term:
        return term["v"]

    return str(term)
|
||||
|
||||
|
||||
def get_node_metadata(socket, flow_id, user, collection, node_uri):
    """Collect label, type, title and format metadata for a node.

    Queries every triple with ``node_uri`` as subject in the source graph.

    Returns a dict that always contains ``"uri"`` and, where a matching
    predicate was found, ``"label"``, ``"type"``, ``"title"`` and
    ``"format"``. If a predicate occurs more than once the last value wins.
    """
    # Predicate -> metadata key, checked in this order.
    wanted = (
        (RDFS_LABEL, "label"),
        (RDF_TYPE, "type"),
        (DC_TITLE, "title"),
        (DC_FORMAT, "format"),
    )

    metadata = {"uri": node_uri}
    for _, predicate, value in query_triples(socket, flow_id, user, collection, s=node_uri, g=SOURCE_GRAPH):
        for known_predicate, key in wanted:
            if predicate == known_predicate:
                metadata[key] = value
                break

    return metadata
|
||||
|
||||
|
||||
def get_children(socket, flow_id, user, collection, parent_uri):
    """Return subjects that are prov:wasDerivedFrom ``parent_uri``.

    The lookup is restricted to the source graph.
    """
    derived = query_triples(
        socket, flow_id, user, collection,
        p=PROV_WAS_DERIVED_FROM, o=parent_uri, g=SOURCE_GRAPH
    )
    children = []
    for child_uri, _, _ in derived:
        children.append(child_uri)
    return children
|
||||
|
||||
|
||||
def get_edges_from_chunk(socket, flow_id, user, collection, chunk_uri):
    """Return the edges (quoted triples) extracted from a chunk.

    First finds statements with ``?stmt prov:wasDerivedFrom chunk_uri`` in
    the source graph, then reads each statement's tg:reifies value. Only
    values decoded as dicts (quoted triples) are kept. One extra query is
    issued per statement.
    """
    statements = query_triples(
        socket, flow_id, user, collection,
        p=PROV_WAS_DERIVED_FROM, o=chunk_uri, g=SOURCE_GRAPH
    )

    collected = []
    for statement in statements:
        stmt_uri = statement[0]
        # What does this statement reify?
        reified = query_triples(
            socket, flow_id, user, collection,
            s=stmt_uri, p=TG_REIFIES, g=SOURCE_GRAPH
        )
        collected.extend(
            target for _, _, target in reified if isinstance(target, dict)
        )

    return collected
|
||||
|
||||
|
||||
def get_document_content(api, user, doc_id, max_content):
    """Fetch a document's content for display via ``api.library()``.

    Returns the UTF-8 decoded text, cut to ``max_content`` characters with a
    "... [truncated]" suffix when longer. Content that is not valid UTF-8 is
    summarised as "[Binary: N bytes]". Any other failure is returned as an
    "[Error fetching content: ...]" string rather than raised.
    """
    try:
        raw = api.library().get_document_content(user=user, id=doc_id)

        # Try to treat the payload as text
        try:
            decoded = raw.decode('utf-8')
        except UnicodeDecodeError:
            return f"[Binary: {len(raw)} bytes]"

        if len(decoded) <= max_content:
            return decoded
        return decoded[:max_content] + "... [truncated]"
    except Exception as e:
        return f"[Error fetching content: {e}]"
|
||||
|
||||
|
||||
def classify_uri(uri):
    """Guess whether a URI names a document, page or chunk.

    Purely pattern-based. Checks are tried in this order:
      1. text after the last "/c" is all digits        -> "chunk"
      2. text after the last "/p" is digits, optionally
         followed by "/"-separated digits (/p1, /p1/)  -> "page"
      3. contains "chunk", "page" or "doc" (any case)  -> "chunk", "page",
         "document" respectively
    Anything else, including non-string input, is "unknown".
    """
    if not isinstance(uri, str):
        return "unknown"

    # Common patterns in trustgraph URIs
    if "/c" in uri and uri.rsplit("/c", 1)[-1].isdigit():
        return "chunk"

    if "/p" in uri:
        tail = uri.rsplit("/p", 1)[-1]
        # The whole tail must be digits once slashes are dropped, and its
        # first path segment must itself be a number.
        if tail.replace("/", "").isdigit() and tail.split("/")[0].isdigit():
            return "page"

    lowered = uri.lower()
    for marker, kind in (("chunk", "chunk"), ("page", "page"), ("doc", "document")):
        if marker in lowered:
            return kind

    return "unknown"
|
||||
|
||||
|
||||
def build_hierarchy(socket, flow_id, user, collection, root_uri, api=None, show_content=False, max_content=200, visited=None):
    """Recursively build the derivation tree rooted at ``root_uri``.

    Each node is a dict with ``uri``, ``type`` (from ``classify_uri``),
    ``metadata``, ``children`` and ``edges``, plus ``content`` when
    ``show_content`` is set and an ``api`` is supplied. Children are the
    nodes derived from this one that classify as page, chunk or unknown;
    edges are only fetched for chunk nodes.

    ``visited`` is the set of URIs already expanded, shared across the
    recursion so that a URI is only expanded once. Returns None when
    ``root_uri`` was already visited.
    """
    seen = set() if visited is None else visited

    if root_uri in seen:
        return None
    seen.add(root_uri)

    metadata = get_node_metadata(socket, flow_id, user, collection, root_uri)
    kind = classify_uri(root_uri)

    node = {
        "uri": root_uri,
        "type": kind,
        "metadata": metadata,
        "children": [],
        "edges": [],
    }

    # Attach content only when asked for and an API handle is available
    if show_content and api:
        content = get_document_content(api, user, root_uri, max_content)
        if content:
            node["content"] = content

    for child_uri in get_children(socket, flow_id, user, collection, root_uri):
        # Descend only into pages, chunks and unclassified nodes
        if classify_uri(child_uri) not in ("page", "chunk", "unknown"):
            continue

        subtree = build_hierarchy(
            socket, flow_id, user, collection, child_uri,
            api=api, show_content=show_content, max_content=max_content,
            visited=seen
        )
        if subtree:
            node["children"].append(subtree)

    # Only chunks carry extracted edges
    if kind == "chunk":
        node["edges"] = get_edges_from_chunk(socket, flow_id, user, collection, root_uri)

    # Stable ordering for consistent output
    node["children"].sort(key=lambda x: x.get("uri", ""))

    return node
|
||||
|
||||
|
||||
def format_edge(edge):
    """Format an edge (quoted triple) as "(s, p, o)" for display.

    Each component is shortened to the text after its last "/". Components
    that are themselves quoted triples (dicts) are formatted recursively.
    A value that is not a dict is returned as ``str(edge)``.
    """
    if not isinstance(edge, dict):
        return str(edge)

    def shorten(term):
        # Nested quoted triple: format it rather than calling .split on a dict.
        if isinstance(term, dict):
            return format_edge(term)
        text = str(term)
        if "/" not in text:
            return text
        # Keep the full value when the URI ends in "/", so the component
        # does not display as an empty string.
        return text.rsplit("/", 1)[-1] or text

    s_short = shorten(edge.get("s", "?"))
    p_short = shorten(edge.get("p", "?"))
    o_short = shorten(edge.get("o", "?"))

    return f"({s_short}, {p_short}, {o_short})"
|
||||
|
||||
|
||||
def print_tree(node, prefix="", is_last=True, show_content=False):
    """Print a hierarchy node and its descendants as an indented tree.

    The root (empty ``prefix``) is printed as "Type: uri" followed by its
    title and format when present. Every other node is printed as a branch
    line using its label, title or last URI segment. With ``show_content``
    set, the first non-blank line among the first three content lines is
    shown, cut to 80 characters. Edges are listed before children.
    """
    uri = node.get("uri", "")
    node_type = node.get("type", "unknown")
    metadata = node.get("metadata", {})

    label = metadata.get("label") or metadata.get("title") or uri.split("/")[-1]
    heading = node_type.capitalize()

    if prefix:
        branch = "└── " if is_last else "├── "
        print(f"{prefix}{branch}{heading}: {label}")
        child_prefix = prefix + (" " if is_last else "│ ")
    else:
        print(f"{heading}: {uri}")
        if metadata.get("title"):
            print(f" Title: \"{metadata['title']}\"")
        if metadata.get("format"):
            print(f" Type: {metadata['format']}")
        child_prefix = " "

    # Show one line of content if available
    if show_content and "content" in node:
        for line in node["content"].split("\n")[:3]:
            if line.strip():
                shown = line[:80] + "..." if len(line) > 80 else line
                print(f"{child_prefix}Content: \"{shown}\"")
                break

    edges = node.get("edges", [])
    children = node.get("children", [])
    total = len(edges) + len(children)

    # An edge only gets the closing connector when no children follow it.
    for position, edge in enumerate(edges, start=1):
        marker = "└── " if position == total else "├── "
        print(f"{child_prefix}{marker}Edge: {format_edge(edge)}")

    # Children are printed recursively
    final_index = len(children) - 1
    for index, child in enumerate(children):
        print_tree(child, child_prefix, index == final_index, show_content)
|
||||
|
||||
|
||||
def print_json(node):
    """Print the hierarchy node to stdout as JSON indented by two spaces."""
    document = json.dumps(node, indent=2)
    print(document)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for tg-show-document-hierarchy.

    Parses the options, opens a socket via ``Api(...).socket()``, builds the
    hierarchy for the given document URI and prints it as a tree or as JSON.
    The socket is always closed. Exits with status 1 when no hierarchy is
    returned or when any exception occurs (reported on stderr).
    """
    parser = argparse.ArgumentParser(
        prog='tg-show-document-hierarchy',
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Declared in display order; each entry is (flags, add_argument kwargs).
    option_specs = [
        (('document_id',), {
            "help": 'Document URI to show hierarchy for',
        }),
        (('-u', '--api-url'), {
            "default": default_url,
            "help": f'API URL (default: {default_url})',
        }),
        (('-t', '--token'), {
            "default": default_token,
            "help": 'Auth token (default: $TRUSTGRAPH_TOKEN)',
        }),
        (('-U', '--user'), {
            "default": default_user,
            "help": f'User ID (default: {default_user})',
        }),
        (('-C', '--collection'), {
            "default": default_collection,
            "help": f'Collection (default: {default_collection})',
        }),
        (('-f', '--flow-id'), {
            "default": 'default',
            "help": 'Flow ID (default: default)',
        }),
        (('--show-content',), {
            "action": 'store_true',
            "help": 'Include blob/document content',
        }),
        (('--max-content',), {
            "type": int,
            "default": 200,
            "help": 'Max chars to display per blob (default: 200)',
        }),
        (('--format',), {
            "choices": ['tree', 'json'],
            "default": 'tree',
            "help": 'Output format: tree (default), json',
        }),
    ]
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    args = parser.parse_args()

    try:
        api = Api(args.api_url, token=args.token)
        socket = api.socket()

        try:
            hierarchy = build_hierarchy(
                socket=socket,
                flow_id=args.flow_id,
                user=args.user,
                collection=args.collection,
                root_uri=args.document_id,
                # The API handle is only needed for fetching content.
                api=api if args.show_content else None,
                show_content=args.show_content,
                max_content=args.max_content,
            )

            if hierarchy is None:
                print(f"No data found for document: {args.document_id}", file=sys.stderr)
                sys.exit(1)

            if args.format == 'json':
                print_json(hierarchy)
            else:
                print_tree(hierarchy, show_content=args.show_content)

        finally:
            # Release the connection whether or not the build succeeded.
            socket.close()

    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
|
||||
558
trustgraph-cli/trustgraph/cli/show_explain_trace.py
Normal file
558
trustgraph-cli/trustgraph/cli/show_explain_trace.py
Normal file
|
|
@ -0,0 +1,558 @@
|
|||
"""
|
||||
Show full explainability trace for a GraphRAG session.
|
||||
|
||||
Given a question/session URI, displays the complete cascade:
|
||||
Question -> Exploration -> Focus (edge selection) -> Synthesis (answer).
|
||||
|
||||
Examples:
|
||||
tg-show-explain-trace -U trustgraph -C default "urn:trustgraph:question:abc123"
|
||||
tg-show-explain-trace --max-answer 1000 "urn:trustgraph:question:abc123"
|
||||
tg-show-explain-trace --show-provenance "urn:trustgraph:question:abc123"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from trustgraph.api import Api
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
# Predicates
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
TG_QUERY = TG + "query"
|
||||
TG_EDGE_COUNT = TG + "edgeCount"
|
||||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_CONTENT = TG + "content"
|
||||
TG_DOCUMENT = TG + "document"
|
||||
TG_REIFIES = TG + "reifies"
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
|
||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
||||
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
|
||||
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
# Graphs
|
||||
RETRIEVAL_GRAPH = "urn:graph:retrieval"
|
||||
SOURCE_GRAPH = "urn:graph:source"
|
||||
|
||||
|
||||
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
    """Query triples through the socket API and return them as plain tuples.

    Builds a non-streaming ``triples`` request for ``flow_id`` and sends it
    with ``socket._send_request_sync``. ``s`` and ``p`` are sent as IRI
    terms. A string ``o`` is sent as an IRI when it starts with ``http://``,
    ``https://`` or ``urn:`` and as a literal otherwise; a dict ``o`` is sent
    unchanged as a ready-made term. ``g`` is passed through as given.

    Returns a list of ``(s, p, o)`` tuples decoded with ``extract_value``.
    Any exception raised while querying or decoding is reported on stderr
    and whatever was collected up to that point is returned.
    """
    request = {
        "user": user,
        "collection": collection,
        "limit": limit,
        "streaming": False,
    }

    if s is not None:
        request["s"] = {"t": "i", "i": s}
    if p is not None:
        request["p"] = {"t": "i", "i": p}
    if isinstance(o, str):
        if o.startswith(("http://", "https://", "urn:")):
            request["o"] = {"t": "i", "i": o}
        else:
            request["o"] = {"t": "l", "v": o}
    elif isinstance(o, dict):
        request["o"] = o
    if g is not None:
        request["g"] = g

    triples = []
    try:
        for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
            # A response is either a dict wrapping the triples or the triples themselves.
            if isinstance(response, dict):
                batch = response.get("response", response.get("triples", []))
            else:
                batch = response

            # Normalise a single item (or an empty value) to a list.
            if not isinstance(batch, list):
                batch = [batch] if batch else []

            for item in batch:
                subject = extract_value(item.get("s", {}))
                predicate = extract_value(item.get("p", {}))
                obj = extract_value(item.get("o", {}))
                triples.append((subject, predicate, obj))
    except Exception as e:
        print(f"Error querying triples: {e}", file=sys.stderr)

    return triples
|
||||
|
||||
|
||||
def extract_value(term):
    """Extract a plain value from a term dict.

    Handles both short (``t``/``i``/``v``/``tr``) and long
    (``type``/``iri``/``value``/``triple``) key spellings.

    Returns:
        ``""`` for an empty/None term; the IRI or literal value for ``i`` and
        ``l`` terms; a ``{"s", "p", "o"}`` dict of extracted values for a
        quoted triple (``t``); otherwise the raw ``i``/``v`` entry if present,
        else ``str(term)``.
    """
    if not term:
        return ""

    # A term that is already a bare value (e.g. a plain string) has nothing
    # to unwrap; calling .get on it would raise AttributeError.
    if not isinstance(term, dict):
        return str(term)

    kind = term.get("t") or term.get("type")

    if kind == "i":
        return term.get("i") or term.get("iri", "")
    if kind == "l":
        return term.get("v") or term.get("value", "")
    if kind == "t":
        # Quoted triple: tolerate a missing, None or malformed payload.
        inner = term.get("tr") or term.get("triple") or {}
        if not isinstance(inner, dict):
            inner = {}
        return {
            "s": extract_value(inner.get("s", {})),
            "p": extract_value(inner.get("p", {})),
            "o": extract_value(inner.get("o", {})),
        }

    # Fallback for raw values without a recognised type tag
    if "i" in term:
        return term["i"]
    if "v" in term:
        return term["v"]

    return str(term)
|
||||
|
||||
|
||||
def get_node_properties(socket, flow_id, user, collection, node_uri, graph=RETRIEVAL_GRAPH):
    """Return all properties of a node as ``{predicate: [values, ...]}``.

    Queries every triple with ``node_uri`` as subject in ``graph`` and
    groups the objects by predicate, keeping the order they were returned.
    """
    grouped = {}
    for _, predicate, value in query_triples(socket, flow_id, user, collection, s=node_uri, g=graph):
        grouped.setdefault(predicate, []).append(value)
    return grouped
|
||||
|
||||
|
||||
def find_by_predicate_object(socket, flow_id, user, collection, predicate, obj, graph=RETRIEVAL_GRAPH):
    """Return the subjects of all triples matching ``predicate`` and ``obj`` in ``graph``."""
    matches = query_triples(socket, flow_id, user, collection, p=predicate, o=obj, g=graph)
    subjects = []
    for subject, _, _ in matches:
        subjects.append(subject)
    return subjects
|
||||
|
||||
|
||||
def get_label(socket, flow_id, user, collection, uri, label_cache):
    """Return the rdfs:label of a URI, caching results in ``label_cache``.

    Values that are not strings, or that do not start with ``http://``,
    ``https://`` or ``urn:``, are returned unchanged without a query. When no
    label is found the URI itself is cached and returned. The lookup does not
    restrict the graph.
    """
    if not isinstance(uri, str):
        return uri
    if not uri.startswith(("http://", "https://", "urn:")):
        return uri

    if uri in label_cache:
        return label_cache[uri]

    matches = query_triples(socket, flow_id, user, collection, s=uri, p=RDFS_LABEL)
    # First label wins; fall back to the URI so a miss is cached as well.
    resolved = matches[0][2] if matches else uri
    label_cache[uri] = resolved
    return resolved
|
||||
|
||||
|
||||
def get_document_content(api, user, doc_id, max_content):
    """Fetch a document's content for display via ``api.library()``.

    Returns the UTF-8 decoded text, cut to ``max_content`` characters with a
    "... [truncated]" suffix when longer. Content that is not valid UTF-8 is
    summarised as "[Binary: N bytes]". Any other failure is returned as an
    "[Error fetching content: ...]" string rather than raised.
    """
    try:
        raw = api.library().get_document_content(user=user, id=doc_id)

        # Try to treat the payload as text
        try:
            decoded = raw.decode('utf-8')
        except UnicodeDecodeError:
            return f"[Binary: {len(raw)} bytes]"

        if len(decoded) <= max_content:
            return decoded
        return decoded[:max_content] + "... [truncated]"
    except Exception as e:
        return f"[Error fetching content: {e}]"
|
||||
|
||||
|
||||
def trace_edge_provenance(socket, flow_id, user, collection, edge_s, edge_p, edge_o, label_cache):
    """Trace an edge back to its source document via reification.

    Looks up statements with ``?stmt tg:reifies <<edge>>`` in the source
    graph (at most 10) and follows each statement's prov:wasDerivedFrom
    chain.

    Args:
        edge_s, edge_p, edge_o: Components of the edge. ``edge_p`` is always
            sent as an IRI; ``edge_s`` and ``edge_o`` are sent as IRIs when
            they are strings starting with ``http://``, ``https://`` or
            ``urn:`` and as literals otherwise.
        label_cache: Dict used by ``get_label`` to cache labels.

    Returns:
        A list of non-empty provenance chains, one per reifying statement.
        Query errors are reported on stderr by ``query_triples`` and yield
        an empty (or partial) result.
    """
    def as_term(value):
        # Same IRI test as query_triples, so "http..." literals are not
        # mistaken for IRIs.
        is_iri = isinstance(value, str) and value.startswith(("http://", "https://", "urn:"))
        return {"t": "i", "i": value} if is_iri else {"t": "l", "v": value}

    # Build the quoted triple for lookup
    quoted_triple = {
        "t": "t",
        "tr": {
            "s": as_term(edge_s),
            "p": {"t": "i", "i": edge_p},
            "o": as_term(edge_o),
        }
    }

    # Query: ?stmt tg:reifies <<edge>>
    # A dict object is passed through by query_triples unchanged, so the
    # shared helper builds the same request and reports failures.
    matches = query_triples(
        socket, flow_id, user, collection,
        p=TG_REIFIES, o=quoted_triple, g=SOURCE_GRAPH, limit=10
    )
    stmt_uris = [stmt for stmt, _, _ in matches if stmt]

    # For each statement, find wasDerivedFrom chain
    provenance_chains = []
    for stmt_uri in stmt_uris:
        chain = trace_provenance_chain(socket, flow_id, user, collection, stmt_uri, label_cache)
        if chain:
            provenance_chains.append(chain)

    return provenance_chains
|
||||
|
||||
|
||||
def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_cache, max_depth=10):
    """Trace the prov:wasDerivedFrom chain from ``start_uri`` towards its root.

    At each step the first parent returned from the source graph is
    followed. The walk stops when there is no parent, when ``max_depth``
    nodes have been visited, or when the parent is already in the chain
    (a cycle of any length).

    Returns:
        A list of ``{"uri", "label"}`` dicts, starting with ``start_uri``.
        Empty when ``start_uri`` is falsy or ``max_depth`` is 0.
    """
    chain = []
    current = start_uri

    for _ in range(max_depth):
        if not current:
            break

        label = get_label(socket, flow_id, user, collection, current, label_cache)
        chain.append({"uri": current, "label": label})

        # Get parent
        triples = query_triples(
            socket, flow_id, user, collection,
            s=current, p=PROV_WAS_DERIVED_FROM, g=SOURCE_GRAPH
        )
        parent = triples[0][2] if triples else None

        # Stop on any revisit, not just a self-loop. Equality is compared
        # against the chain instead of a set so unhashable values are safe.
        if not parent or any(step["uri"] == parent for step in chain):
            break
        current = parent

    return chain
|
||||
|
||||
|
||||
def format_provenance_chain(chain):
    """Format a provenance chain as "label -> label -> ..." for display.

    Each item contributes its label, falling back to its URI and then to
    "?" when the label is missing or empty. Values are converted with
    ``str`` so non-string labels cannot break the join. An empty or None
    chain gives ``""``.
    """
    if not chain:
        return ""
    labels = [
        str(item.get("label") or item.get("uri") or "?")
        for item in chain
    ]
    return " -> ".join(labels)
|
||||
|
||||
|
||||
def format_edge(edge, label_cache=None, socket=None, flow_id=None, user=None, collection=None):
    """Format a quoted triple edge as "(s, p, o)" for display.

    When both ``label_cache`` and ``socket`` are supplied, each component is
    resolved with ``get_label``. An empty cache dict still counts as
    supplied. Otherwise each component is shortened to the text after its
    last "/", with nested quoted triples (dicts) formatted recursively.
    A value that is not a dict is returned as ``str(edge)``.
    """
    if not isinstance(edge, dict):
        return str(edge)

    parts = (edge.get("s", "?"), edge.get("p", "?"), edge.get("o", "?"))

    def shorten(term):
        # Nested quoted triple: format it rather than calling .split on a dict.
        if isinstance(term, dict):
            return format_edge(term)
        text = str(term)
        if "/" not in text:
            return text
        # Keep the full value when the URI ends in "/".
        return text.rsplit("/", 1)[-1] or text

    # Compare against None: a fresh cache is an empty dict, which is falsy,
    # and a truthiness test would disable label lookup on first use.
    if label_cache is not None and socket is not None:
        s_label, p_label, o_label = (
            get_label(socket, flow_id, user, collection, part, label_cache)
            for part in parts
        )
    else:
        # Shorten URIs for display
        s_label, p_label, o_label = (shorten(part) for part in parts)

    return f"({s_label}, {p_label}, {o_label})"
|
||||
|
||||
|
||||
def build_trace(socket, flow_id, user, collection, question_id, api=None, show_provenance=False, max_answer=500):
    """Build the full explainability trace for a question.

    Walks question -> exploration -> focus -> synthesis.  At each stage only
    the first matching node is used, and a later stage is looked up only
    when the previous one was found.

    Args:
        socket, flow_id, user, collection: Passed through to the query
            helpers.
        question_id: URI of the question/session node.
        api: Optional API object; when given it is used to fetch the answer
            through ``get_document_content`` if the synthesis node has no
            inline content but references a document.
        show_provenance: When true, each selected edge that is a dict gets a
            ``"provenance"`` entry produced by ``trace_edge_provenance``.
        max_answer: Inline answers longer than this many characters are cut
            and suffixed with ``"... [truncated]"``.

    Returns:
        A dict with the keys ``question_id``, ``question``, ``time``,
        ``exploration``, ``focus``, ``synthesis`` (stages not found stay
        ``None``) plus ``_label_cache``, the label cache used while tracing.
    """

    def node_props(node_id):
        return get_node_properties(socket, flow_id, user, collection, node_id)

    def subjects_with(predicate, obj):
        return find_by_predicate_object(
            socket, flow_id, user, collection, predicate, obj
        )

    def first(props, predicate):
        # Property values are lists; take the first, or None when absent.
        return props.get(predicate, [None])[0]

    label_cache = {}

    # Question metadata
    question_props = node_props(question_id)
    trace = {
        "question_id": question_id,
        "question": first(question_props, TG_QUERY),
        "time": first(question_props, PROV_STARTED_AT_TIME),
        "exploration": None,
        "focus": None,
        "synthesis": None,
    }

    # Exploration: ?exploration prov:wasGeneratedBy question_id
    exploration_ids = subjects_with(PROV_WAS_GENERATED_BY, question_id)

    if exploration_ids:
        exploration_id = exploration_ids[0]
        trace["exploration"] = {
            "id": exploration_id,
            "edge_count": first(node_props(exploration_id), TG_EDGE_COUNT),
        }

        # Focus: ?focus prov:wasDerivedFrom exploration_id
        focus_ids = subjects_with(PROV_WAS_DERIVED_FROM, exploration_id)

        if focus_ids:
            focus_id = focus_ids[0]
            selection_uris = node_props(focus_id).get(TG_SELECTED_EDGE, [])

            selected_edges = []
            for selection_uri in selection_uris:
                selection_props = node_props(selection_uri)
                edge = first(selection_props, TG_EDGE)
                entry = {
                    "edge": edge,
                    "reasoning": first(selection_props, TG_REASONING),
                }

                # Provenance is only traced for edges that arrived as dicts.
                if show_provenance and isinstance(edge, dict):
                    entry["provenance"] = trace_edge_provenance(
                        socket, flow_id, user, collection,
                        edge.get("s", ""), edge.get("p", ""), edge.get("o", ""),
                        label_cache
                    )

                selected_edges.append(entry)

            trace["focus"] = {
                "id": focus_id,
                "selected_edges": selected_edges,
            }

            # Synthesis: ?synthesis prov:wasDerivedFrom focus_id
            synthesis_ids = subjects_with(PROV_WAS_DERIVED_FROM, focus_id)

            if synthesis_ids:
                synthesis_id = synthesis_ids[0]
                synthesis_props = node_props(synthesis_id)

                # The answer is either inline or held in a referenced document.
                answer = first(synthesis_props, TG_CONTENT)
                doc_id = first(synthesis_props, TG_DOCUMENT)

                if not answer and doc_id and api:
                    answer = get_document_content(api, user, doc_id, max_answer)
                elif answer and len(answer) > max_answer:
                    answer = answer[:max_answer] + "... [truncated]"

                trace["synthesis"] = {
                    "id": synthesis_id,
                    "document_id": doc_id,
                    "answer": answer,
                }

    # Keep the label cache alongside the trace for formatting
    trace["_label_cache"] = label_cache

    return trace
|
||||
|
||||
|
||||
def print_text(trace, show_provenance=False):
    """Print trace in text format.

    Writes four sections to stdout: a session header (with question and
    time when present), exploration, focus (edge selection) and synthesis.

    Args:
        trace: Dict with a ``"question_id"``, ``"question"`` and ``"time"``
            entry (read by subscript, so they must be present) and optional
            ``"exploration"``, ``"focus"`` and ``"synthesis"`` entries.
        show_provenance: When true, the ``"provenance"`` chains stored on
            each selected edge are printed as ``Source:`` lines.
    """
    print(f"=== GraphRAG Session: {trace['question_id']} ===")
    print()

    if trace["question"]:
        print(f"Question: {trace['question']}")
    if trace["time"]:
        print(f"Time: {trace['time']}")
    print()

    # Exploration
    print("--- Exploration ---")
    exploration = trace.get("exploration")
    if exploration:
        # The key may be present with a None value, in which case a
        # .get() default would never apply; show "?" explicitly instead
        # of printing "None".
        edge_count = exploration.get("edge_count")
        if edge_count is None:
            edge_count = "?"
        print(f"Retrieved {edge_count} edges from knowledge graph")
    else:
        print("No exploration data found")
    print()

    # Focus
    print("--- Focus (Edge Selection) ---")
    focus = trace.get("focus")
    if focus:
        edges = focus.get("selected_edges", [])
        print(f"Selected {len(edges)} edges:")
        print()

        for i, edge_info in enumerate(edges, 1):
            edge = edge_info.get("edge")
            reasoning = edge_info.get("reasoning")

            if edge:
                edge_str = format_edge(edge)
                print(f" {i}. {edge_str}")

            if reasoning:
                # Keep the listing compact: long reasoning is cut at 100 chars.
                r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
                print(f" Reasoning: {r_short}")

            if show_provenance:
                provenance = edge_info.get("provenance", [])
                for chain in provenance:
                    chain_str = format_provenance_chain(chain)
                    if chain_str:
                        print(f" Source: {chain_str}")

            print()
    else:
        print("No focus data found")
        print()

    # Synthesis
    print("--- Synthesis ---")
    synthesis = trace.get("synthesis")
    if synthesis:
        answer = synthesis.get("answer")
        if answer:
            print("Answer:")
            # Indent the answer
            for line in answer.split("\n"):
                print(f" {line}")
        else:
            print("No answer content found")
    else:
        print("No synthesis data found")
|
||||
|
||||
|
||||
def print_json(trace):
    """Print trace as JSON.

    Top-level entries whose key starts with an underscore (such as the
    internal label cache) are left out of the output.
    """
    public = {}
    for key, value in trace.items():
        if key.startswith("_"):
            # Internal bookkeeping, not part of the printed trace.
            continue
        public[key] = value

    rendered = json.dumps(public, indent=2)
    print(rendered)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for tg-show-explain-trace.

    Parses the arguments, builds the trace for the given question over an
    API socket and prints it as text or JSON.  The socket is always closed;
    any exception is reported on stderr and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        prog='tg-show-explain-trace',
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    add = parser.add_argument

    add('question_id', help='Question/session URI to show trace for')
    add('-u', '--api-url', default=default_url,
        help=f'API URL (default: {default_url})')
    add('-t', '--token', default=default_token,
        help='Auth token (default: $TRUSTGRAPH_TOKEN)')
    add('-U', '--user', default=default_user,
        help=f'User ID (default: {default_user})')
    add('-C', '--collection', default=default_collection,
        help=f'Collection (default: {default_collection})')
    add('-f', '--flow-id', default='default',
        help='Flow ID (default: default)')
    add('--max-answer', type=int, default=500,
        help='Max chars for answer display (default: 500)')
    add('--show-provenance', action='store_true',
        help='Also trace edges back to source documents')
    add('--format', choices=['text', 'json'], default='text',
        help='Output format: text (default), json')

    opts = parser.parse_args()

    try:
        api = Api(opts.api_url, token=opts.token)
        socket = api.socket()

        try:
            trace = build_trace(
                socket=socket,
                flow_id=opts.flow_id,
                user=opts.user,
                collection=opts.collection,
                question_id=opts.question_id,
                api=api,
                show_provenance=opts.show_provenance,
                max_answer=opts.max_answer,
            )

            if opts.format != 'json':
                print_text(trace, show_provenance=opts.show_provenance)
            else:
                print_json(trace)

        finally:
            # Release the connection whether or not the trace succeeded.
            socket.close()

    except Exception as exc:
        print(f"Error: {exc}", file=sys.stderr)
        sys.exit(1)
|
||||
|
||||
|
||||
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue