mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 17:36:23 +02:00
Add unified explainability support and librarian storage for (#693)
Add unified explainability support and librarian storage for all retrieval engines Implements consistent explainability/provenance tracking across GraphRAG, DocumentRAG, and Agent retrieval engines. All large content (answers, thoughts, observations) is now stored in librarian rather than as inline literals in the knowledge graph. Explainability API: - New explainability.py module with entity classes (Question, Exploration, Focus, Synthesis, Analysis, Conclusion) and ExplainabilityClient - Quiescence-based eventual consistency handling for trace fetching - Content fetching from librarian with retry logic CLI updates: - tg-invoke-graph-rag -x/--explainable flag returns explain_id - tg-invoke-document-rag -x/--explainable flag returns explain_id - tg-invoke-agent -x/--explainable flag returns explain_id - tg-list-explain-traces uses new explainability API - tg-show-explain-trace handles all three trace types Agent provenance: - Records session, iterations (think/act/observe), and conclusion - Stores thoughts and observations in librarian with document references - New predicates: tg:thoughtDocument, tg:observationDocument DocumentRAG provenance: - Records question, exploration (chunk retrieval), and synthesis - Stores answers in librarian with document references Schema changes: - AgentResponse: added explain_id, explain_graph fields - RetrievalResponse: added explain_id, explain_graph fields - agent_iteration_triples: supports thought_document_id, observation_document_id Update tests.
This commit is contained in:
parent
aecf00f040
commit
35128ff019
24 changed files with 2736 additions and 846 deletions
|
|
@ -4,8 +4,19 @@ Uses the agent service to answer a question
|
|||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import textwrap
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api import (
|
||||
Api,
|
||||
ExplainabilityClient,
|
||||
ProvenanceEvent,
|
||||
Question,
|
||||
Analysis,
|
||||
Conclusion,
|
||||
AgentThought,
|
||||
AgentObservation,
|
||||
AgentAnswer,
|
||||
)
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
|
@ -97,11 +108,148 @@ def output(text, prefix="> ", width=78):
|
|||
)
|
||||
print(out)
|
||||
|
||||
def question_explainable(
|
||||
url, question_text, flow_id, user, collection,
|
||||
state=None, group=None, verbose=False, token=None, debug=False
|
||||
):
|
||||
"""Execute agent with explainability - shows provenance events inline."""
|
||||
api = Api(url=url, token=token)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(flow_id)
|
||||
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
|
||||
|
||||
try:
|
||||
# Track last chunk type for formatting
|
||||
last_chunk_type = None
|
||||
current_outputter = None
|
||||
|
||||
# Stream agent with explainability - process events as they arrive
|
||||
for item in flow.agent_explain(
|
||||
question=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
state=state,
|
||||
group=group,
|
||||
):
|
||||
if isinstance(item, AgentThought):
|
||||
if last_chunk_type != "thought":
|
||||
if current_outputter:
|
||||
current_outputter.__exit__(None, None, None)
|
||||
current_outputter = None
|
||||
print() # Blank line between message types
|
||||
if verbose:
|
||||
current_outputter = Outputter(width=78, prefix="\U0001f914 ")
|
||||
current_outputter.__enter__()
|
||||
last_chunk_type = "thought"
|
||||
if current_outputter:
|
||||
current_outputter.output(item.content)
|
||||
if current_outputter.word_buffer:
|
||||
print(current_outputter.word_buffer, end="", flush=True)
|
||||
current_outputter.column += len(current_outputter.word_buffer)
|
||||
current_outputter.word_buffer = ""
|
||||
|
||||
elif isinstance(item, AgentObservation):
|
||||
if last_chunk_type != "observation":
|
||||
if current_outputter:
|
||||
current_outputter.__exit__(None, None, None)
|
||||
current_outputter = None
|
||||
print()
|
||||
if verbose:
|
||||
current_outputter = Outputter(width=78, prefix="\U0001f4a1 ")
|
||||
current_outputter.__enter__()
|
||||
last_chunk_type = "observation"
|
||||
if current_outputter:
|
||||
current_outputter.output(item.content)
|
||||
if current_outputter.word_buffer:
|
||||
print(current_outputter.word_buffer, end="", flush=True)
|
||||
current_outputter.column += len(current_outputter.word_buffer)
|
||||
current_outputter.word_buffer = ""
|
||||
|
||||
elif isinstance(item, AgentAnswer):
|
||||
if last_chunk_type != "answer":
|
||||
if current_outputter:
|
||||
current_outputter.__exit__(None, None, None)
|
||||
current_outputter = None
|
||||
print()
|
||||
last_chunk_type = "answer"
|
||||
# Print answer content directly
|
||||
print(item.content, end="", flush=True)
|
||||
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
# Process provenance event immediately
|
||||
prov_id = item.explain_id
|
||||
explain_graph = item.explain_graph or "urn:graph:retrieval"
|
||||
|
||||
entity = explain_client.fetch_entity(
|
||||
prov_id,
|
||||
graph=explain_graph,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
if entity is None:
|
||||
if debug:
|
||||
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
# Display based on entity type
|
||||
if isinstance(entity, Question):
|
||||
print(f"\n [session] {prov_id}", file=sys.stderr)
|
||||
if entity.query:
|
||||
print(f" Query: {entity.query}", file=sys.stderr)
|
||||
if entity.timestamp:
|
||||
print(f" Time: {entity.timestamp}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Analysis):
|
||||
print(f"\n [iteration] {prov_id}", file=sys.stderr)
|
||||
if entity.thought:
|
||||
thought_short = entity.thought[:80] + "..." if len(entity.thought) > 80 else entity.thought
|
||||
print(f" Thought: {thought_short}", file=sys.stderr)
|
||||
if entity.action:
|
||||
print(f" Action: {entity.action}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Conclusion):
|
||||
print(f"\n [conclusion] {prov_id}", file=sys.stderr)
|
||||
if entity.answer:
|
||||
print(f" Answer length: {len(entity.answer)} chars", file=sys.stderr)
|
||||
|
||||
else:
|
||||
if debug:
|
||||
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
|
||||
|
||||
# Close any remaining outputter
|
||||
if current_outputter:
|
||||
current_outputter.__exit__(None, None, None)
|
||||
current_outputter = None
|
||||
|
||||
# Final newline if we ended with answer
|
||||
if last_chunk_type == "answer":
|
||||
print()
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
|
||||
def question(
|
||||
url, question, flow_id, user, collection,
|
||||
plan=None, state=None, group=None, verbose=False, streaming=True,
|
||||
token=None
|
||||
token=None, explainable=False, debug=False
|
||||
):
|
||||
# Explainable mode uses the API to capture and process provenance events
|
||||
if explainable:
|
||||
question_explainable(
|
||||
url=url,
|
||||
question_text=question,
|
||||
flow_id=flow_id,
|
||||
user=user,
|
||||
collection=collection,
|
||||
state=state,
|
||||
group=group,
|
||||
verbose=verbose,
|
||||
token=token,
|
||||
debug=debug
|
||||
)
|
||||
return
|
||||
|
||||
if verbose:
|
||||
output(wrap(question), "\U00002753 ")
|
||||
|
|
@ -270,6 +418,18 @@ def main():
|
|||
help=f'Disable streaming (use legacy mode)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--explainable',
|
||||
action='store_true',
|
||||
help='Show provenance events: Session, Iterations, Conclusion (implies streaming)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--debug',
|
||||
action='store_true',
|
||||
help='Show debug output for troubleshooting'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -286,6 +446,8 @@ def main():
|
|||
verbose = args.verbose,
|
||||
streaming = not args.no_streaming,
|
||||
token = args.token,
|
||||
explainable = args.explainable,
|
||||
debug = args.debug,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,16 @@ Uses the DocumentRAG service to answer a question
|
|||
|
||||
import argparse
|
||||
import os
|
||||
from trustgraph.api import Api
|
||||
import sys
|
||||
from trustgraph.api import (
|
||||
Api,
|
||||
ExplainabilityClient,
|
||||
RAGChunk,
|
||||
ProvenanceEvent,
|
||||
Question,
|
||||
Exploration,
|
||||
Synthesis,
|
||||
)
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
|
@ -12,7 +21,90 @@ default_user = 'trustgraph'
|
|||
default_collection = 'default'
|
||||
default_doc_limit = 10
|
||||
|
||||
def question(url, flow_id, question, user, collection, doc_limit, streaming=True, token=None):
|
||||
|
||||
def question_explainable(
|
||||
url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False
|
||||
):
|
||||
"""Execute document RAG with explainability - shows provenance events inline."""
|
||||
api = Api(url=url, token=token)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(flow_id)
|
||||
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
|
||||
|
||||
try:
|
||||
# Stream DocumentRAG with explainability - process events as they arrive
|
||||
for item in flow.document_rag_explain(
|
||||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
# Print response content
|
||||
print(item.content, end="", flush=True)
|
||||
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
# Process provenance event immediately
|
||||
prov_id = item.explain_id
|
||||
explain_graph = item.explain_graph or "urn:graph:retrieval"
|
||||
|
||||
entity = explain_client.fetch_entity(
|
||||
prov_id,
|
||||
graph=explain_graph,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
if entity is None:
|
||||
if debug:
|
||||
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
# Display based on entity type
|
||||
if isinstance(entity, Question):
|
||||
print(f"\n [question] {prov_id}", file=sys.stderr)
|
||||
if entity.query:
|
||||
print(f" Query: {entity.query}", file=sys.stderr)
|
||||
if entity.timestamp:
|
||||
print(f" Time: {entity.timestamp}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Exploration):
|
||||
print(f"\n [exploration] {prov_id}", file=sys.stderr)
|
||||
if entity.chunk_count:
|
||||
print(f" Chunks retrieved: {entity.chunk_count}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Synthesis):
|
||||
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
|
||||
if entity.content:
|
||||
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
|
||||
|
||||
else:
|
||||
if debug:
|
||||
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
|
||||
|
||||
print() # Final newline
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
|
||||
def question(
|
||||
url, flow_id, question_text, user, collection, doc_limit,
|
||||
streaming=True, token=None, explainable=False, debug=False
|
||||
):
|
||||
# Explainable mode uses the API to capture and process provenance events
|
||||
if explainable:
|
||||
question_explainable(
|
||||
url=url,
|
||||
flow_id=flow_id,
|
||||
question_text=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
token=token,
|
||||
debug=debug
|
||||
)
|
||||
return
|
||||
|
||||
# Create API client
|
||||
api = Api(url=url, token=token)
|
||||
|
|
@ -24,7 +116,7 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True
|
|||
|
||||
try:
|
||||
response = flow.document_rag(
|
||||
query=question,
|
||||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
|
|
@ -42,13 +134,14 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True
|
|||
# Use REST API for non-streaming
|
||||
flow = api.flow().id(flow_id)
|
||||
resp = flow.document_rag(
|
||||
query=question,
|
||||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
)
|
||||
print(resp)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -105,6 +198,18 @@ def main():
|
|||
help='Disable streaming (use non-streaming mode)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--explainable',
|
||||
action='store_true',
|
||||
help='Show provenance events: Question, Exploration, Synthesis (implies streaming)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--debug',
|
||||
action='store_true',
|
||||
help='Show debug output for troubleshooting'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -112,12 +217,14 @@ def main():
|
|||
question(
|
||||
url=args.url,
|
||||
flow_id=args.flow_id,
|
||||
question=args.question,
|
||||
question_text=args.question,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
doc_limit=args.doc_limit,
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,16 @@ import os
|
|||
import sys
|
||||
import websockets
|
||||
import asyncio
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api import (
|
||||
Api,
|
||||
ExplainabilityClient,
|
||||
RAGChunk,
|
||||
ProvenanceEvent,
|
||||
Question,
|
||||
Exploration,
|
||||
Focus,
|
||||
Synthesis,
|
||||
)
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
|
@ -602,18 +611,111 @@ async def _question_explainable(
|
|||
print() # Final newline
|
||||
|
||||
|
||||
def _question_explainable_api(
|
||||
url, flow_id, question_text, user, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, token=None, debug=False
|
||||
):
|
||||
"""Execute graph RAG with explainability using the new API classes."""
|
||||
api = Api(url=url, token=token)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(flow_id)
|
||||
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
|
||||
|
||||
try:
|
||||
# Stream GraphRAG with explainability - process events as they arrive
|
||||
for item in flow.graph_rag_explain(
|
||||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_subgraph_count=5,
|
||||
max_entity_distance=max_path_length,
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
# Print response content
|
||||
print(item.content, end="", flush=True)
|
||||
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
# Process provenance event immediately
|
||||
prov_id = item.explain_id
|
||||
explain_graph = item.explain_graph or "urn:graph:retrieval"
|
||||
|
||||
entity = explain_client.fetch_entity(
|
||||
prov_id,
|
||||
graph=explain_graph,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
if entity is None:
|
||||
if debug:
|
||||
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
# Display based on entity type
|
||||
if isinstance(entity, Question):
|
||||
print(f"\n [question] {prov_id}", file=sys.stderr)
|
||||
if entity.query:
|
||||
print(f" Query: {entity.query}", file=sys.stderr)
|
||||
if entity.timestamp:
|
||||
print(f" Time: {entity.timestamp}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Exploration):
|
||||
print(f"\n [exploration] {prov_id}", file=sys.stderr)
|
||||
if entity.edge_count:
|
||||
print(f" Edges explored: {entity.edge_count}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Focus):
|
||||
print(f"\n [focus] {prov_id}", file=sys.stderr)
|
||||
if entity.selected_edge_uris:
|
||||
print(f" Focused on {len(entity.selected_edge_uris)} edge(s)", file=sys.stderr)
|
||||
|
||||
# Fetch full focus with edge details
|
||||
focus_full = explain_client.fetch_focus_with_edges(
|
||||
prov_id,
|
||||
graph=explain_graph,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
if focus_full and focus_full.edge_selections:
|
||||
for edge_sel in focus_full.edge_selections:
|
||||
if edge_sel.edge:
|
||||
# Resolve labels for edge components
|
||||
s_label, p_label, o_label = explain_client.resolve_edge_labels(
|
||||
edge_sel.edge, user, collection
|
||||
)
|
||||
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
|
||||
if edge_sel.reasoning:
|
||||
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
|
||||
print(f" Reason: {r_short}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Synthesis):
|
||||
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
|
||||
if entity.content:
|
||||
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
|
||||
|
||||
else:
|
||||
if debug:
|
||||
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
|
||||
|
||||
print() # Final newline
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
|
||||
def question(
|
||||
url, flow_id, question, user, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, streaming=True, token=None,
|
||||
explainable=False, debug=False
|
||||
):
|
||||
|
||||
# Explainable mode uses direct websocket to capture provenance events
|
||||
# Explainable mode uses the API to capture and process provenance events
|
||||
if explainable:
|
||||
asyncio.run(_question_explainable(
|
||||
_question_explainable_api(
|
||||
url=url,
|
||||
flow_id=flow_id,
|
||||
question=question,
|
||||
question_text=question,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
|
|
@ -622,7 +724,7 @@ def question(
|
|||
max_path_length=max_path_length,
|
||||
token=token,
|
||||
debug=debug
|
||||
))
|
||||
)
|
||||
return
|
||||
|
||||
# Create API client
|
||||
|
|
|
|||
|
|
@ -14,180 +14,17 @@ import json
|
|||
import os
|
||||
import sys
|
||||
from tabulate import tabulate
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api import Api, ExplainabilityClient
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
# Predicates
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
TG_QUERY = TG + "query"
|
||||
TG_QUESTION = TG + "Question"
|
||||
TG_ANALYSIS = TG + "Analysis"
|
||||
TG_EXPLORATION = TG + "Exploration"
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
|
||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
||||
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
|
||||
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
|
||||
# Retrieval graph
|
||||
RETRIEVAL_GRAPH = "urn:graph:retrieval"
|
||||
|
||||
|
||||
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
|
||||
"""Query triples using the socket API."""
|
||||
request = {
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit,
|
||||
"streaming": False,
|
||||
}
|
||||
|
||||
if s is not None:
|
||||
request["s"] = {"t": "i", "i": s}
|
||||
if p is not None:
|
||||
request["p"] = {"t": "i", "i": p}
|
||||
if o is not None:
|
||||
if isinstance(o, str):
|
||||
if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"):
|
||||
request["o"] = {"t": "i", "i": o}
|
||||
else:
|
||||
request["o"] = {"t": "l", "v": o}
|
||||
elif isinstance(o, dict):
|
||||
request["o"] = o
|
||||
if g is not None:
|
||||
request["g"] = g
|
||||
|
||||
triples = []
|
||||
try:
|
||||
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
|
||||
if isinstance(response, dict):
|
||||
triple_list = response.get("response", response.get("triples", []))
|
||||
else:
|
||||
triple_list = response
|
||||
|
||||
if not isinstance(triple_list, list):
|
||||
triple_list = [triple_list] if triple_list else []
|
||||
|
||||
for t in triple_list:
|
||||
s_val = extract_value(t.get("s", {}))
|
||||
p_val = extract_value(t.get("p", {}))
|
||||
o_val = extract_value(t.get("o", {}))
|
||||
triples.append((s_val, p_val, o_val))
|
||||
except Exception as e:
|
||||
print(f"Error querying triples: {e}", file=sys.stderr)
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
def extract_value(term):
|
||||
"""Extract value from a term dict."""
|
||||
if not term:
|
||||
return ""
|
||||
|
||||
t = term.get("t") or term.get("type")
|
||||
|
||||
if t == "i":
|
||||
return term.get("i") or term.get("iri", "")
|
||||
elif t == "l":
|
||||
return term.get("v") or term.get("value", "")
|
||||
elif t == "t":
|
||||
# Quoted triple
|
||||
tr = term.get("tr") or term.get("triple", {})
|
||||
return {
|
||||
"s": extract_value(tr.get("s", {})),
|
||||
"p": extract_value(tr.get("p", {})),
|
||||
"o": extract_value(tr.get("o", {})),
|
||||
}
|
||||
|
||||
# Fallback for raw values
|
||||
if "i" in term:
|
||||
return term["i"]
|
||||
if "v" in term:
|
||||
return term["v"]
|
||||
|
||||
return str(term)
|
||||
|
||||
|
||||
def get_timestamp(socket, flow_id, user, collection, question_id):
|
||||
"""Get timestamp for a question."""
|
||||
triples = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
s=question_id, p=PROV_STARTED_AT_TIME, g=RETRIEVAL_GRAPH
|
||||
)
|
||||
for s, p, o in triples:
|
||||
return o
|
||||
return ""
|
||||
|
||||
|
||||
def get_session_type(socket, flow_id, user, collection, session_id):
|
||||
"""
|
||||
Get the type of session (Agent or GraphRAG).
|
||||
|
||||
Both have tg:Question type, so we distinguish by URI pattern
|
||||
or by checking what's derived from it.
|
||||
"""
|
||||
# Fast path: check URI pattern
|
||||
if session_id.startswith("urn:trustgraph:agent:"):
|
||||
return "Agent"
|
||||
if session_id.startswith("urn:trustgraph:question:"):
|
||||
return "GraphRAG"
|
||||
|
||||
# Check what's derived from this entity
|
||||
derived = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
p=PROV_WAS_DERIVED_FROM, o=session_id, g=RETRIEVAL_GRAPH
|
||||
)
|
||||
generated = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
p=PROV_WAS_GENERATED_BY, o=session_id, g=RETRIEVAL_GRAPH
|
||||
)
|
||||
|
||||
for s, p, o in derived + generated:
|
||||
child_types = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
s=s, p=RDF_TYPE, g=RETRIEVAL_GRAPH
|
||||
)
|
||||
for _, _, child_type in child_types:
|
||||
if child_type == TG_ANALYSIS:
|
||||
return "Agent"
|
||||
if child_type == TG_EXPLORATION:
|
||||
return "GraphRAG"
|
||||
|
||||
return "GraphRAG"
|
||||
|
||||
|
||||
def list_sessions(socket, flow_id, user, collection, limit):
|
||||
"""List all explainability sessions (GraphRAG and Agent) by finding questions."""
|
||||
# Query for all triples with predicate = tg:query
|
||||
triples = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
p=TG_QUERY, g=RETRIEVAL_GRAPH, limit=limit
|
||||
)
|
||||
|
||||
sessions = []
|
||||
for question_id, _, query_text in triples:
|
||||
# Get timestamp if available
|
||||
timestamp = get_timestamp(socket, flow_id, user, collection, question_id)
|
||||
# Get session type (Agent or GraphRAG)
|
||||
session_type = get_session_type(socket, flow_id, user, collection, question_id)
|
||||
|
||||
sessions.append({
|
||||
"id": question_id,
|
||||
"type": session_type,
|
||||
"question": query_text,
|
||||
"time": timestamp,
|
||||
})
|
||||
|
||||
# Sort by timestamp (newest first) if available
|
||||
sessions.sort(key=lambda x: x.get("time", ""), reverse=True)
|
||||
|
||||
return sessions
|
||||
|
||||
|
||||
def truncate_text(text, max_len=60):
|
||||
"""Truncate text to max length with ellipsis."""
|
||||
if not text:
|
||||
|
|
@ -277,16 +114,42 @@ def main():
|
|||
try:
|
||||
api = Api(args.api_url, token=args.token)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(args.flow_id)
|
||||
explain_client = ExplainabilityClient(flow)
|
||||
|
||||
try:
|
||||
sessions = list_sessions(
|
||||
socket=socket,
|
||||
flow_id=args.flow_id,
|
||||
# List all sessions using the API
|
||||
questions = explain_client.list_sessions(
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
# Convert to output format
|
||||
sessions = []
|
||||
for q in questions:
|
||||
session_type = explain_client.detect_session_type(
|
||||
q.uri,
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection
|
||||
)
|
||||
|
||||
# Map type names
|
||||
type_display = {
|
||||
"graphrag": "GraphRAG",
|
||||
"docrag": "DocRAG",
|
||||
"agent": "Agent",
|
||||
}.get(session_type, session_type.title())
|
||||
|
||||
sessions.append({
|
||||
"id": q.uri,
|
||||
"type": type_display,
|
||||
"question": q.query,
|
||||
"time": q.timestamp,
|
||||
})
|
||||
|
||||
if args.format == 'json':
|
||||
print_json(sessions)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -291,42 +291,25 @@ def query_graph(
|
|||
):
|
||||
"""Query the triple store with pattern matching.
|
||||
|
||||
Uses the WebSocket API's raw streaming mode for efficient delivery of results.
|
||||
Uses the API's triples_query_stream for efficient streaming delivery.
|
||||
"""
|
||||
socket = Api(url, token=token).socket()
|
||||
|
||||
# Build request dict directly (bypassing triples_query_stream's string conversion)
|
||||
request = {
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit,
|
||||
"streaming": True,
|
||||
"batch-size": batch_size,
|
||||
}
|
||||
|
||||
# Add term dicts for s/p/o (None means wildcard)
|
||||
if subject is not None:
|
||||
request["s"] = subject
|
||||
if predicate is not None:
|
||||
request["p"] = predicate
|
||||
if obj is not None:
|
||||
request["o"] = obj
|
||||
if graph is not None:
|
||||
request["g"] = graph
|
||||
flow = socket.flow(flow_id)
|
||||
|
||||
all_triples = []
|
||||
|
||||
try:
|
||||
# Use raw streaming mode - yields response dicts directly
|
||||
for response in socket._send_request_sync(
|
||||
"triples", flow_id, request, streaming_raw=True
|
||||
# Use triples_query_stream - accepts Term dicts directly
|
||||
for triples in flow.triples_query_stream(
|
||||
s=subject,
|
||||
p=predicate,
|
||||
o=obj,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=limit,
|
||||
batch_size=batch_size,
|
||||
):
|
||||
# Response may have triples in different locations depending on format
|
||||
if isinstance(response, dict):
|
||||
triples = response.get("response", response.get("triples", []))
|
||||
else:
|
||||
triples = response
|
||||
|
||||
if not isinstance(triples, list):
|
||||
triples = [triples] if triples else []
|
||||
|
||||
|
|
|
|||
|
|
@ -18,228 +18,99 @@ import argparse
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api import (
|
||||
Api,
|
||||
ExplainabilityClient,
|
||||
Question,
|
||||
Exploration,
|
||||
Focus,
|
||||
Synthesis,
|
||||
Analysis,
|
||||
Conclusion,
|
||||
)
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
# Predicates
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
TG_QUERY = TG + "query"
|
||||
TG_EDGE_COUNT = TG + "edgeCount"
|
||||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_CONTENT = TG + "content"
|
||||
TG_DOCUMENT = TG + "document"
|
||||
TG_REIFIES = TG + "reifies"
|
||||
# Explainability entity types
|
||||
TG_QUESTION = TG + "Question"
|
||||
TG_EXPLORATION = TG + "Exploration"
|
||||
TG_FOCUS = TG + "Focus"
|
||||
TG_SYNTHESIS = TG + "Synthesis"
|
||||
TG_ANALYSIS = TG + "Analysis"
|
||||
TG_CONCLUSION = TG + "Conclusion"
|
||||
|
||||
# Agent predicates
|
||||
TG_THOUGHT = TG + "thought"
|
||||
TG_ACTION = TG + "action"
|
||||
TG_ARGUMENTS = TG + "arguments"
|
||||
TG_OBSERVATION = TG + "observation"
|
||||
TG_ANSWER = TG + "answer"
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
|
||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
||||
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
|
||||
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
# Graphs
|
||||
RETRIEVAL_GRAPH = "urn:graph:retrieval"
|
||||
SOURCE_GRAPH = "urn:graph:source"
|
||||
|
||||
|
||||
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
|
||||
"""Query triples using the socket API."""
|
||||
request = {
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit,
|
||||
"streaming": False,
|
||||
}
|
||||
|
||||
if s is not None:
|
||||
request["s"] = {"t": "i", "i": s}
|
||||
if p is not None:
|
||||
request["p"] = {"t": "i", "i": p}
|
||||
if o is not None:
|
||||
if isinstance(o, str):
|
||||
if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"):
|
||||
request["o"] = {"t": "i", "i": o}
|
||||
else:
|
||||
request["o"] = {"t": "l", "v": o}
|
||||
elif isinstance(o, dict):
|
||||
request["o"] = o
|
||||
if g is not None:
|
||||
request["g"] = g
|
||||
|
||||
triples = []
|
||||
try:
|
||||
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
|
||||
if isinstance(response, dict):
|
||||
triple_list = response.get("response", response.get("triples", []))
|
||||
else:
|
||||
triple_list = response
|
||||
|
||||
if not isinstance(triple_list, list):
|
||||
triple_list = [triple_list] if triple_list else []
|
||||
|
||||
for t in triple_list:
|
||||
s_val = extract_value(t.get("s", {}))
|
||||
p_val = extract_value(t.get("p", {}))
|
||||
o_val = extract_value(t.get("o", {}))
|
||||
triples.append((s_val, p_val, o_val))
|
||||
except Exception as e:
|
||||
print(f"Error querying triples: {e}", file=sys.stderr)
|
||||
|
||||
return triples
|
||||
# Provenance predicates for edge tracing
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
TG_REIFIES = TG + "reifies"
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
||||
|
||||
|
||||
def extract_value(term):
|
||||
"""Extract value from a term dict."""
|
||||
if not term:
|
||||
return ""
|
||||
def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_client):
|
||||
"""
|
||||
Trace an edge back to its source document via reification.
|
||||
|
||||
t = term.get("t") or term.get("type")
|
||||
Args:
|
||||
flow: SocketFlowInstance
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
edge: Dict with s, p, o keys
|
||||
label_cache: Dict for caching labels
|
||||
explain_client: ExplainabilityClient for label resolution
|
||||
|
||||
if t == "i":
|
||||
return term.get("i") or term.get("iri", "")
|
||||
elif t == "l":
|
||||
return term.get("v") or term.get("value", "")
|
||||
elif t == "t":
|
||||
# Quoted triple
|
||||
tr = term.get("tr") or term.get("triple", {})
|
||||
return {
|
||||
"s": extract_value(tr.get("s", {})),
|
||||
"p": extract_value(tr.get("p", {})),
|
||||
"o": extract_value(tr.get("o", {})),
|
||||
}
|
||||
Returns:
|
||||
List of provenance chains, each chain is list of {uri, label}
|
||||
"""
|
||||
edge_s = edge.get("s", "")
|
||||
edge_p = edge.get("p", "")
|
||||
edge_o = edge.get("o", "")
|
||||
|
||||
# Fallback for raw values
|
||||
if "i" in term:
|
||||
return term["i"]
|
||||
if "v" in term:
|
||||
return term["v"]
|
||||
# Build quoted triple for lookup
|
||||
def build_term(val):
|
||||
if isinstance(val, str) and (val.startswith("http") or val.startswith("urn:")):
|
||||
return {"t": "i", "i": val}
|
||||
return {"t": "l", "v": str(val)}
|
||||
|
||||
return str(term)
|
||||
|
||||
|
||||
def get_node_properties(socket, flow_id, user, collection, node_uri, graph=RETRIEVAL_GRAPH):
|
||||
"""Get all properties of a node as a dict."""
|
||||
triples = query_triples(socket, flow_id, user, collection, s=node_uri, g=graph)
|
||||
props = {}
|
||||
for s, p, o in triples:
|
||||
if p not in props:
|
||||
props[p] = []
|
||||
props[p].append(o)
|
||||
return props
|
||||
|
||||
|
||||
def find_by_predicate_object(socket, flow_id, user, collection, predicate, obj, graph=RETRIEVAL_GRAPH):
|
||||
"""Find subjects where predicate = obj."""
|
||||
triples = query_triples(socket, flow_id, user, collection, p=predicate, o=obj, g=graph)
|
||||
return [s for s, p, o in triples]
|
||||
|
||||
|
||||
def get_label(socket, flow_id, user, collection, uri, label_cache):
|
||||
"""Get label for a URI, with caching."""
|
||||
if not isinstance(uri, str) or not (uri.startswith("http://") or uri.startswith("https://") or uri.startswith("urn:")):
|
||||
return uri
|
||||
|
||||
if uri in label_cache:
|
||||
return label_cache[uri]
|
||||
|
||||
triples = query_triples(socket, flow_id, user, collection, s=uri, p=RDFS_LABEL)
|
||||
for s, p, o in triples:
|
||||
label_cache[uri] = o
|
||||
return o
|
||||
|
||||
label_cache[uri] = uri
|
||||
return uri
|
||||
|
||||
|
||||
def get_document_content(api, user, doc_id, max_content):
|
||||
"""Fetch document content from librarian API."""
|
||||
try:
|
||||
library = api.library()
|
||||
content = library.get_document_content(user=user, id=doc_id)
|
||||
|
||||
# Try to decode as text
|
||||
try:
|
||||
text = content.decode('utf-8')
|
||||
if len(text) > max_content:
|
||||
return text[:max_content] + "... [truncated]"
|
||||
return text
|
||||
except UnicodeDecodeError:
|
||||
return f"[Binary: {len(content)} bytes]"
|
||||
except Exception as e:
|
||||
return f"[Error fetching content: {e}]"
|
||||
|
||||
|
||||
def trace_edge_provenance(socket, flow_id, user, collection, edge_s, edge_p, edge_o, label_cache):
|
||||
"""Trace an edge back to its source document via reification."""
|
||||
# Build the quoted triple for lookup
|
||||
quoted_triple = {
|
||||
"t": "t",
|
||||
"tr": {
|
||||
"s": {"t": "i", "i": edge_s} if isinstance(edge_s, str) and (edge_s.startswith("http") or edge_s.startswith("urn:")) else {"t": "l", "v": edge_s},
|
||||
"p": {"t": "i", "i": edge_p},
|
||||
"o": {"t": "i", "i": edge_o} if isinstance(edge_o, str) and (edge_o.startswith("http") or edge_o.startswith("urn:")) else {"t": "l", "v": edge_o},
|
||||
"s": build_term(edge_s),
|
||||
"p": build_term(edge_p),
|
||||
"o": build_term(edge_o),
|
||||
}
|
||||
}
|
||||
|
||||
# Query: ?stmt tg:reifies <<edge>>
|
||||
request = {
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": 10,
|
||||
"streaming": False,
|
||||
"p": {"t": "i", "i": TG_REIFIES},
|
||||
"o": quoted_triple,
|
||||
"g": SOURCE_GRAPH,
|
||||
}
|
||||
|
||||
stmt_uris = []
|
||||
try:
|
||||
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
|
||||
if isinstance(response, dict):
|
||||
triple_list = response.get("response", response.get("triples", []))
|
||||
else:
|
||||
triple_list = response
|
||||
|
||||
if not isinstance(triple_list, list):
|
||||
triple_list = [triple_list] if triple_list else []
|
||||
|
||||
for t in triple_list:
|
||||
s_val = extract_value(t.get("s", {}))
|
||||
if s_val:
|
||||
stmt_uris.append(s_val)
|
||||
results = flow.triples_query(
|
||||
p=TG_REIFIES,
|
||||
o=quoted_triple,
|
||||
g=SOURCE_GRAPH,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=10
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
# For each statement, find wasDerivedFrom chain
|
||||
# Extract statement URIs
|
||||
stmt_uris = []
|
||||
for t in results:
|
||||
s_term = t.get("s", {})
|
||||
s_val = s_term.get("i") or s_term.get("v", "")
|
||||
if s_val:
|
||||
stmt_uris.append(s_val)
|
||||
|
||||
# For each statement, trace wasDerivedFrom chain
|
||||
provenance_chains = []
|
||||
for stmt_uri in stmt_uris:
|
||||
chain = trace_provenance_chain(socket, flow_id, user, collection, stmt_uri, label_cache)
|
||||
chain = trace_provenance_chain(flow, user, collection, stmt_uri, label_cache, explain_client)
|
||||
if chain:
|
||||
provenance_chains.append(chain)
|
||||
|
||||
return provenance_chains
|
||||
|
||||
|
||||
def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_cache, max_depth=10):
|
||||
def trace_provenance_chain(flow, user, collection, start_uri, label_cache, explain_client, max_depth=10):
|
||||
"""Trace prov:wasDerivedFrom chain from start_uri to root."""
|
||||
chain = []
|
||||
current = start_uri
|
||||
|
|
@ -248,17 +119,32 @@ def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_c
|
|||
if not current:
|
||||
break
|
||||
|
||||
label = get_label(socket, flow_id, user, collection, current, label_cache)
|
||||
# Get label
|
||||
if current in label_cache:
|
||||
label = label_cache[current]
|
||||
else:
|
||||
label = explain_client.resolve_label(current, user, collection)
|
||||
label_cache[current] = label
|
||||
|
||||
chain.append({"uri": current, "label": label})
|
||||
|
||||
# Get parent
|
||||
triples = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
s=current, p=PROV_WAS_DERIVED_FROM, g=SOURCE_GRAPH
|
||||
)
|
||||
# Get parent via wasDerivedFrom
|
||||
try:
|
||||
results = flow.triples_query(
|
||||
s=current,
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
g=SOURCE_GRAPH,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=1
|
||||
)
|
||||
except Exception:
|
||||
break
|
||||
|
||||
parent = None
|
||||
for s, p, o in triples:
|
||||
parent = o
|
||||
for t in results:
|
||||
o_term = t.get("o", {})
|
||||
parent = o_term.get("i") or o_term.get("v", "")
|
||||
break
|
||||
|
||||
if not parent or parent == current:
|
||||
|
|
@ -276,331 +162,24 @@ def format_provenance_chain(chain):
|
|||
return " -> ".join(labels)
|
||||
|
||||
|
||||
def format_edge(edge, label_cache=None, socket=None, flow_id=None, user=None, collection=None):
|
||||
"""Format a quoted triple edge for display."""
|
||||
if not isinstance(edge, dict):
|
||||
return str(edge)
|
||||
def print_graphrag_text(trace, explain_client, flow, user, collection, show_provenance=False):
|
||||
"""Print GraphRAG trace in text format."""
|
||||
question = trace.get("question")
|
||||
|
||||
s = edge.get("s", "?")
|
||||
p = edge.get("p", "?")
|
||||
o = edge.get("o", "?")
|
||||
|
||||
# Get labels if available
|
||||
if label_cache and socket:
|
||||
s_label = get_label(socket, flow_id, user, collection, s, label_cache)
|
||||
p_label = get_label(socket, flow_id, user, collection, p, label_cache)
|
||||
o_label = get_label(socket, flow_id, user, collection, o, label_cache)
|
||||
else:
|
||||
# Shorten URIs for display
|
||||
s_label = s.split("/")[-1] if "/" in str(s) else s
|
||||
p_label = p.split("/")[-1] if "/" in str(p) else p
|
||||
o_label = o.split("/")[-1] if "/" in str(o) else o
|
||||
|
||||
return f"({s_label}, {p_label}, {o_label})"
|
||||
|
||||
|
||||
def detect_trace_type(socket, flow_id, user, collection, entity_id):
|
||||
"""
|
||||
Detect whether an entity is an agent Question or GraphRAG Question.
|
||||
|
||||
Both have rdf:type = tg:Question, so we distinguish by checking
|
||||
what's derived from it:
|
||||
- Agent: has tg:Analysis or tg:Conclusion derived
|
||||
- GraphRAG: has tg:Exploration derived
|
||||
|
||||
Also checks URI pattern as fallback:
|
||||
- urn:trustgraph:agent: -> agent
|
||||
- urn:trustgraph:question: -> graphrag
|
||||
|
||||
Returns:
|
||||
"agent" or "graphrag"
|
||||
"""
|
||||
# Check URI pattern first (fast path)
|
||||
if entity_id.startswith("urn:trustgraph:agent:"):
|
||||
return "agent"
|
||||
if entity_id.startswith("urn:trustgraph:question:"):
|
||||
return "graphrag"
|
||||
|
||||
# Check what's derived from this entity
|
||||
derived = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_DERIVED_FROM, entity_id
|
||||
)
|
||||
|
||||
# Also check wasGeneratedBy (GraphRAG exploration uses this)
|
||||
generated = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_GENERATED_BY, entity_id
|
||||
)
|
||||
|
||||
all_children = derived + generated
|
||||
|
||||
for child_id in all_children:
|
||||
child_types = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
s=child_id, p=RDF_TYPE, g=RETRIEVAL_GRAPH
|
||||
)
|
||||
for s, p, o in child_types:
|
||||
if o == TG_ANALYSIS or o == TG_CONCLUSION:
|
||||
return "agent"
|
||||
if o == TG_EXPLORATION:
|
||||
return "graphrag"
|
||||
|
||||
# Default to graphrag
|
||||
return "graphrag"
|
||||
|
||||
|
||||
def build_agent_trace(socket, flow_id, user, collection, session_id, api=None, max_answer=500):
|
||||
"""Build the full explainability trace for an agent session."""
|
||||
trace = {
|
||||
"session_id": session_id,
|
||||
"type": "agent",
|
||||
"question": None,
|
||||
"time": None,
|
||||
"iterations": [],
|
||||
"final_answer": None,
|
||||
}
|
||||
|
||||
# Get session metadata
|
||||
props = get_node_properties(socket, flow_id, user, collection, session_id)
|
||||
trace["question"] = props.get(TG_QUERY, [None])[0]
|
||||
trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0]
|
||||
|
||||
# Find all entities derived from this session (iterations and final)
|
||||
# Start by looking for entities where prov:wasDerivedFrom = session_id
|
||||
current_uri = session_id
|
||||
iteration_num = 1
|
||||
|
||||
while True:
|
||||
# Find entities derived from current
|
||||
derived_ids = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_DERIVED_FROM, current_uri
|
||||
)
|
||||
|
||||
if not derived_ids:
|
||||
break
|
||||
|
||||
derived_id = derived_ids[0]
|
||||
derived_props = get_node_properties(socket, flow_id, user, collection, derived_id)
|
||||
|
||||
# Check type
|
||||
types = derived_props.get(RDF_TYPE, [])
|
||||
|
||||
if TG_ANALYSIS in types:
|
||||
iteration = {
|
||||
"id": derived_id,
|
||||
"iteration_num": iteration_num,
|
||||
"thought": derived_props.get(TG_THOUGHT, [None])[0],
|
||||
"action": derived_props.get(TG_ACTION, [None])[0],
|
||||
"arguments": derived_props.get(TG_ARGUMENTS, [None])[0],
|
||||
"observation": derived_props.get(TG_OBSERVATION, [None])[0],
|
||||
}
|
||||
trace["iterations"].append(iteration)
|
||||
current_uri = derived_id
|
||||
iteration_num += 1
|
||||
|
||||
elif TG_CONCLUSION in types:
|
||||
answer = derived_props.get(TG_ANSWER, [None])[0]
|
||||
if answer and len(answer) > max_answer:
|
||||
answer = answer[:max_answer] + "... [truncated]"
|
||||
trace["final_answer"] = {
|
||||
"id": derived_id,
|
||||
"answer": answer,
|
||||
}
|
||||
break
|
||||
|
||||
else:
|
||||
# Unknown type, stop traversal
|
||||
break
|
||||
|
||||
return trace
|
||||
|
||||
|
||||
def print_agent_text(trace):
|
||||
"""Print agent trace in text format."""
|
||||
print(f"=== Agent Session: {trace['session_id']} ===")
|
||||
print(f"=== GraphRAG Session: {question.uri if question else 'Unknown'} ===")
|
||||
print()
|
||||
|
||||
if trace["question"]:
|
||||
print(f"Question: {trace['question']}")
|
||||
if trace["time"]:
|
||||
print(f"Time: {trace['time']}")
|
||||
print()
|
||||
|
||||
# Analysis steps
|
||||
print("--- Analysis ---")
|
||||
iterations = trace.get("iterations", [])
|
||||
if iterations:
|
||||
for iteration in iterations:
|
||||
print(f"Analysis {iteration['iteration_num']}:")
|
||||
print(f" Thought: {iteration.get('thought', 'N/A')}")
|
||||
print(f" Action: {iteration.get('action', 'N/A')}")
|
||||
|
||||
args = iteration.get('arguments')
|
||||
if args:
|
||||
# Try to pretty-print JSON arguments
|
||||
try:
|
||||
import json
|
||||
args_obj = json.loads(args)
|
||||
args_str = json.dumps(args_obj, indent=4)
|
||||
# Indent each line
|
||||
args_lines = args_str.split('\n')
|
||||
print(f" Arguments:")
|
||||
for line in args_lines:
|
||||
print(f" {line}")
|
||||
except:
|
||||
print(f" Arguments: {args}")
|
||||
else:
|
||||
print(f" Arguments: N/A")
|
||||
|
||||
obs = iteration.get('observation', 'N/A')
|
||||
if obs and len(obs) > 200:
|
||||
obs = obs[:200] + "... [truncated]"
|
||||
print(f" Observation: {obs}")
|
||||
print()
|
||||
else:
|
||||
print("No analysis steps recorded")
|
||||
print()
|
||||
|
||||
# Conclusion
|
||||
print("--- Conclusion ---")
|
||||
final = trace.get("final_answer")
|
||||
if final and final.get("answer"):
|
||||
print("Answer:")
|
||||
for line in final["answer"].split("\n"):
|
||||
print(f" {line}")
|
||||
else:
|
||||
print("No conclusion recorded")
|
||||
|
||||
|
||||
def print_agent_json(trace):
|
||||
"""Print agent trace as JSON."""
|
||||
print(json.dumps(trace, indent=2))
|
||||
|
||||
|
||||
def build_trace(socket, flow_id, user, collection, question_id, api=None, show_provenance=False, max_answer=500):
|
||||
"""Build the full explainability trace for a question."""
|
||||
label_cache = {}
|
||||
|
||||
trace = {
|
||||
"question_id": question_id,
|
||||
"question": None,
|
||||
"time": None,
|
||||
"exploration": None,
|
||||
"focus": None,
|
||||
"synthesis": None,
|
||||
}
|
||||
|
||||
# Get question metadata
|
||||
props = get_node_properties(socket, flow_id, user, collection, question_id)
|
||||
trace["question"] = props.get(TG_QUERY, [None])[0]
|
||||
trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0]
|
||||
|
||||
# Find exploration: ?exploration prov:wasGeneratedBy question_id
|
||||
exploration_ids = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_GENERATED_BY, question_id
|
||||
)
|
||||
|
||||
if exploration_ids:
|
||||
exploration_id = exploration_ids[0]
|
||||
exploration_props = get_node_properties(socket, flow_id, user, collection, exploration_id)
|
||||
trace["exploration"] = {
|
||||
"id": exploration_id,
|
||||
"edge_count": exploration_props.get(TG_EDGE_COUNT, [None])[0],
|
||||
}
|
||||
|
||||
# Find focus: ?focus prov:wasDerivedFrom exploration_id
|
||||
focus_ids = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_DERIVED_FROM, exploration_id
|
||||
)
|
||||
|
||||
if focus_ids:
|
||||
focus_id = focus_ids[0]
|
||||
focus_props = get_node_properties(socket, flow_id, user, collection, focus_id)
|
||||
|
||||
# Get selected edges
|
||||
edge_selection_uris = focus_props.get(TG_SELECTED_EDGE, [])
|
||||
selected_edges = []
|
||||
|
||||
for edge_sel_uri in edge_selection_uris:
|
||||
edge_sel_props = get_node_properties(socket, flow_id, user, collection, edge_sel_uri)
|
||||
edge = edge_sel_props.get(TG_EDGE, [None])[0]
|
||||
reasoning = edge_sel_props.get(TG_REASONING, [None])[0]
|
||||
|
||||
edge_info = {
|
||||
"edge": edge,
|
||||
"reasoning": reasoning,
|
||||
}
|
||||
|
||||
# Trace provenance if requested
|
||||
if show_provenance and isinstance(edge, dict):
|
||||
provenance = trace_edge_provenance(
|
||||
socket, flow_id, user, collection,
|
||||
edge.get("s", ""), edge.get("p", ""), edge.get("o", ""),
|
||||
label_cache
|
||||
)
|
||||
edge_info["provenance"] = provenance
|
||||
|
||||
selected_edges.append(edge_info)
|
||||
|
||||
trace["focus"] = {
|
||||
"id": focus_id,
|
||||
"selected_edges": selected_edges,
|
||||
}
|
||||
|
||||
# Find synthesis: ?synthesis prov:wasDerivedFrom focus_id
|
||||
synthesis_ids = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_DERIVED_FROM, focus_id
|
||||
)
|
||||
|
||||
if synthesis_ids:
|
||||
synthesis_id = synthesis_ids[0]
|
||||
synthesis_props = get_node_properties(socket, flow_id, user, collection, synthesis_id)
|
||||
|
||||
# Get content directly or via document reference
|
||||
content = synthesis_props.get(TG_CONTENT, [None])[0]
|
||||
doc_id = synthesis_props.get(TG_DOCUMENT, [None])[0]
|
||||
|
||||
if not content and doc_id and api:
|
||||
content = get_document_content(api, user, doc_id, max_answer)
|
||||
elif content and len(content) > max_answer:
|
||||
content = content[:max_answer] + "... [truncated]"
|
||||
|
||||
trace["synthesis"] = {
|
||||
"id": synthesis_id,
|
||||
"document_id": doc_id,
|
||||
"answer": content,
|
||||
}
|
||||
|
||||
# Store label cache for formatting
|
||||
trace["_label_cache"] = label_cache
|
||||
|
||||
return trace
|
||||
|
||||
|
||||
def print_text(trace, show_provenance=False):
|
||||
"""Print trace in text format."""
|
||||
label_cache = trace.get("_label_cache", {})
|
||||
|
||||
print(f"=== GraphRAG Session: {trace['question_id']} ===")
|
||||
print()
|
||||
|
||||
if trace["question"]:
|
||||
print(f"Question: {trace['question']}")
|
||||
if trace["time"]:
|
||||
print(f"Time: {trace['time']}")
|
||||
if question:
|
||||
print(f"Question: {question.query}")
|
||||
if question.timestamp:
|
||||
print(f"Time: {question.timestamp}")
|
||||
print()
|
||||
|
||||
# Exploration
|
||||
print("--- Exploration ---")
|
||||
exploration = trace.get("exploration")
|
||||
if exploration:
|
||||
edge_count = exploration.get("edge_count", "?")
|
||||
print(f"Retrieved {edge_count} edges from knowledge graph")
|
||||
print(f"Retrieved {exploration.edge_count} edges from knowledge graph")
|
||||
else:
|
||||
print("No exploration data found")
|
||||
print()
|
||||
|
|
@ -609,24 +188,28 @@ def print_text(trace, show_provenance=False):
|
|||
print("--- Focus (Edge Selection) ---")
|
||||
focus = trace.get("focus")
|
||||
if focus:
|
||||
edges = focus.get("selected_edges", [])
|
||||
edges = focus.edge_selections
|
||||
print(f"Selected {len(edges)} edges:")
|
||||
print()
|
||||
|
||||
for i, edge_info in enumerate(edges, 1):
|
||||
edge = edge_info.get("edge")
|
||||
reasoning = edge_info.get("reasoning")
|
||||
label_cache = {}
|
||||
|
||||
if edge:
|
||||
edge_str = format_edge(edge)
|
||||
print(f" {i}. {edge_str}")
|
||||
for i, edge_sel in enumerate(edges, 1):
|
||||
if edge_sel.edge:
|
||||
s_label, p_label, o_label = explain_client.resolve_edge_labels(
|
||||
edge_sel.edge, user, collection
|
||||
)
|
||||
print(f" {i}. ({s_label}, {p_label}, {o_label})")
|
||||
|
||||
if reasoning:
|
||||
r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
|
||||
if edge_sel.reasoning:
|
||||
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
|
||||
print(f" Reasoning: {r_short}")
|
||||
|
||||
if show_provenance:
|
||||
provenance = edge_info.get("provenance", [])
|
||||
if show_provenance and edge_sel.edge:
|
||||
provenance = trace_edge_provenance(
|
||||
flow, user, collection, edge_sel.edge,
|
||||
label_cache, explain_client
|
||||
)
|
||||
for chain in provenance:
|
||||
chain_str = format_provenance_chain(chain)
|
||||
if chain_str:
|
||||
|
|
@ -641,11 +224,9 @@ def print_text(trace, show_provenance=False):
|
|||
print("--- Synthesis ---")
|
||||
synthesis = trace.get("synthesis")
|
||||
if synthesis:
|
||||
answer = synthesis.get("answer")
|
||||
if answer:
|
||||
if synthesis.content:
|
||||
print("Answer:")
|
||||
# Indent the answer
|
||||
for line in answer.split("\n"):
|
||||
for line in synthesis.content.split("\n"):
|
||||
print(f" {line}")
|
||||
else:
|
||||
print("No answer content found")
|
||||
|
|
@ -653,11 +234,173 @@ def print_text(trace, show_provenance=False):
|
|||
print("No synthesis data found")
|
||||
|
||||
|
||||
def print_json(trace):
|
||||
"""Print trace as JSON."""
|
||||
# Remove internal cache before printing
|
||||
output = {k: v for k, v in trace.items() if not k.startswith("_")}
|
||||
print(json.dumps(output, indent=2))
|
||||
def print_docrag_text(trace):
|
||||
"""Print DocRAG trace in text format."""
|
||||
question = trace.get("question")
|
||||
|
||||
print(f"=== DocRAG Session: {question.uri if question else 'Unknown'} ===")
|
||||
print()
|
||||
|
||||
if question:
|
||||
print(f"Question: {question.query}")
|
||||
if question.timestamp:
|
||||
print(f"Time: {question.timestamp}")
|
||||
print()
|
||||
|
||||
# Exploration
|
||||
print("--- Exploration ---")
|
||||
exploration = trace.get("exploration")
|
||||
if exploration:
|
||||
print(f"Retrieved {exploration.chunk_count} chunks from document store")
|
||||
else:
|
||||
print("No exploration data found")
|
||||
print()
|
||||
|
||||
# Synthesis (no Focus step for DocRAG)
|
||||
print("--- Synthesis ---")
|
||||
synthesis = trace.get("synthesis")
|
||||
if synthesis:
|
||||
if synthesis.content:
|
||||
print("Answer:")
|
||||
for line in synthesis.content.split("\n"):
|
||||
print(f" {line}")
|
||||
else:
|
||||
print("No answer content found")
|
||||
else:
|
||||
print("No synthesis data found")
|
||||
|
||||
|
||||
def print_agent_text(trace):
|
||||
"""Print Agent trace in text format."""
|
||||
question = trace.get("question")
|
||||
|
||||
print(f"=== Agent Session: {question.uri if question else 'Unknown'} ===")
|
||||
print()
|
||||
|
||||
if question:
|
||||
print(f"Question: {question.query}")
|
||||
if question.timestamp:
|
||||
print(f"Time: {question.timestamp}")
|
||||
print()
|
||||
|
||||
# Analysis steps
|
||||
print("--- Analysis ---")
|
||||
iterations = trace.get("iterations", [])
|
||||
if iterations:
|
||||
for i, analysis in enumerate(iterations, 1):
|
||||
print(f"Analysis {i}:")
|
||||
print(f" Thought: {analysis.thought or 'N/A'}")
|
||||
print(f" Action: {analysis.action or 'N/A'}")
|
||||
|
||||
if analysis.arguments:
|
||||
# Try to pretty-print JSON arguments
|
||||
try:
|
||||
args_obj = json.loads(analysis.arguments)
|
||||
args_str = json.dumps(args_obj, indent=4)
|
||||
print(f" Arguments:")
|
||||
for line in args_str.split('\n'):
|
||||
print(f" {line}")
|
||||
except Exception:
|
||||
print(f" Arguments: {analysis.arguments}")
|
||||
else:
|
||||
print(f" Arguments: N/A")
|
||||
|
||||
obs = analysis.observation or 'N/A'
|
||||
if obs and len(obs) > 200:
|
||||
obs = obs[:200] + "... [truncated]"
|
||||
print(f" Observation: {obs}")
|
||||
print()
|
||||
else:
|
||||
print("No analysis steps recorded")
|
||||
print()
|
||||
|
||||
# Conclusion
|
||||
print("--- Conclusion ---")
|
||||
conclusion = trace.get("conclusion")
|
||||
if conclusion and conclusion.answer:
|
||||
print("Answer:")
|
||||
for line in conclusion.answer.split("\n"):
|
||||
print(f" {line}")
|
||||
else:
|
||||
print("No conclusion recorded")
|
||||
|
||||
|
||||
def trace_to_dict(trace, trace_type):
|
||||
"""Convert trace entities to JSON-serializable dict."""
|
||||
if trace_type == "agent":
|
||||
question = trace.get("question")
|
||||
return {
|
||||
"type": "agent",
|
||||
"session_id": question.uri if question else None,
|
||||
"question": question.query if question else None,
|
||||
"time": question.timestamp if question else None,
|
||||
"iterations": [
|
||||
{
|
||||
"id": a.uri,
|
||||
"thought": a.thought,
|
||||
"action": a.action,
|
||||
"arguments": a.arguments,
|
||||
"observation": a.observation,
|
||||
}
|
||||
for a in trace.get("iterations", [])
|
||||
],
|
||||
"conclusion": {
|
||||
"id": trace["conclusion"].uri,
|
||||
"answer": trace["conclusion"].answer,
|
||||
} if trace.get("conclusion") else None,
|
||||
}
|
||||
elif trace_type == "docrag":
|
||||
question = trace.get("question")
|
||||
exploration = trace.get("exploration")
|
||||
synthesis = trace.get("synthesis")
|
||||
|
||||
return {
|
||||
"type": "docrag",
|
||||
"question_id": question.uri if question else None,
|
||||
"question": question.query if question else None,
|
||||
"time": question.timestamp if question else None,
|
||||
"exploration": {
|
||||
"id": exploration.uri,
|
||||
"chunk_count": exploration.chunk_count,
|
||||
} if exploration else None,
|
||||
"synthesis": {
|
||||
"id": synthesis.uri,
|
||||
"document_uri": synthesis.document_uri,
|
||||
"answer": synthesis.content,
|
||||
} if synthesis else None,
|
||||
}
|
||||
else:
|
||||
# graphrag
|
||||
question = trace.get("question")
|
||||
exploration = trace.get("exploration")
|
||||
focus = trace.get("focus")
|
||||
synthesis = trace.get("synthesis")
|
||||
|
||||
return {
|
||||
"type": "graphrag",
|
||||
"question_id": question.uri if question else None,
|
||||
"question": question.query if question else None,
|
||||
"time": question.timestamp if question else None,
|
||||
"exploration": {
|
||||
"id": exploration.uri,
|
||||
"edge_count": exploration.edge_count,
|
||||
} if exploration else None,
|
||||
"focus": {
|
||||
"id": focus.uri,
|
||||
"selected_edges": [
|
||||
{
|
||||
"edge": edge_sel.edge,
|
||||
"reasoning": edge_sel.reasoning,
|
||||
}
|
||||
for edge_sel in focus.edge_selections
|
||||
],
|
||||
} if focus else None,
|
||||
"synthesis": {
|
||||
"id": synthesis.uri,
|
||||
"document_uri": synthesis.document_uri,
|
||||
"answer": synthesis.content,
|
||||
} if synthesis else None,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -727,50 +470,69 @@ def main():
|
|||
try:
|
||||
api = Api(args.api_url, token=args.token)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(args.flow_id)
|
||||
explain_client = ExplainabilityClient(flow)
|
||||
|
||||
try:
|
||||
# Detect trace type (agent vs graphrag)
|
||||
trace_type = detect_trace_type(
|
||||
socket=socket,
|
||||
flow_id=args.flow_id,
|
||||
# Detect trace type
|
||||
trace_type = explain_client.detect_session_type(
|
||||
args.question_id,
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
entity_id=args.question_id,
|
||||
)
|
||||
|
||||
if trace_type == "agent":
|
||||
# Build and print agent trace
|
||||
trace = build_agent_trace(
|
||||
socket=socket,
|
||||
flow_id=args.flow_id,
|
||||
# Fetch and display agent trace
|
||||
trace = explain_client.fetch_agent_trace(
|
||||
args.question_id,
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
session_id=args.question_id,
|
||||
api=api,
|
||||
max_answer=args.max_answer,
|
||||
max_content=args.max_answer,
|
||||
)
|
||||
|
||||
if args.format == 'json':
|
||||
print_agent_json(trace)
|
||||
print(json.dumps(trace_to_dict(trace, "agent"), indent=2))
|
||||
else:
|
||||
print_agent_text(trace)
|
||||
else:
|
||||
# Build and print GraphRAG trace (existing behavior)
|
||||
trace = build_trace(
|
||||
socket=socket,
|
||||
flow_id=args.flow_id,
|
||||
|
||||
elif trace_type == "docrag":
|
||||
# Fetch and display DocRAG trace
|
||||
trace = explain_client.fetch_docrag_trace(
|
||||
args.question_id,
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
question_id=args.question_id,
|
||||
api=api,
|
||||
show_provenance=args.show_provenance,
|
||||
max_answer=args.max_answer,
|
||||
max_content=args.max_answer,
|
||||
)
|
||||
|
||||
if args.format == 'json':
|
||||
print_json(trace)
|
||||
print(json.dumps(trace_to_dict(trace, "docrag"), indent=2))
|
||||
else:
|
||||
print_text(trace, show_provenance=args.show_provenance)
|
||||
print_docrag_text(trace)
|
||||
|
||||
else:
|
||||
# Fetch and display GraphRAG trace
|
||||
trace = explain_client.fetch_graphrag_trace(
|
||||
args.question_id,
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
api=api,
|
||||
max_content=args.max_answer,
|
||||
)
|
||||
|
||||
if args.format == 'json':
|
||||
print(json.dumps(trace_to_dict(trace, "graphrag"), indent=2))
|
||||
else:
|
||||
print_graphrag_text(
|
||||
trace, explain_client, flow,
|
||||
args.user, args.collection,
|
||||
show_provenance=args.show_provenance
|
||||
)
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue