Add unified explainability support and librarian storage for (#693)

Add unified explainability support and librarian storage for all retrieval engines

Implements consistent explainability/provenance tracking
across GraphRAG, DocumentRAG, and Agent retrieval
engines. All large content (answers, thoughts, observations)
is now stored in librarian rather than as inline literals in
the knowledge graph.

Explainability API:
- New explainability.py module with entity classes (Question,
  Exploration, Focus, Synthesis, Analysis, Conclusion) and
  ExplainabilityClient
- Quiescence-based eventual consistency handling for trace
  fetching
- Content fetching from librarian with retry logic

CLI updates:
- tg-invoke-graph-rag -x/--explainable flag returns
  explain_id
- tg-invoke-document-rag -x/--explainable flag returns
  explain_id
- tg-invoke-agent -x/--explainable flag returns explain_id
- tg-list-explain-traces uses new explainability API
- tg-show-explain-trace handles all three trace types

Agent provenance:
- Records session, iterations (think/act/observe), and conclusion
- Stores thoughts and observations in librarian with document
  references
- New predicates: tg:thoughtDocument, tg:observationDocument

DocumentRAG provenance:
- Records question, exploration (chunk retrieval), and synthesis
- Stores answers in librarian with document references

Schema changes:
- AgentResponse: added explain_id, explain_graph fields
- RetrievalResponse: added explain_id, explain_graph fields
- agent_iteration_triples: supports thought_document_id,
  observation_document_id

Update tests.
This commit is contained in:
cybermaggedon 2026-03-12 21:40:09 +00:00 committed by GitHub
parent aecf00f040
commit 35128ff019
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 2736 additions and 846 deletions

View file

@ -110,16 +110,17 @@ class TestRAGTranslatorCompletionFlags:
assert response_dict["end_of_stream"] is True
assert response_dict["end_of_session"] is False
def test_document_rag_translator_is_final_with_end_of_stream_true(self):
def test_document_rag_translator_is_final_with_end_of_session_true(self):
"""
Test that DocumentRagResponseTranslator returns is_final=True
when end_of_stream=True.
when end_of_session=True.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("document-rag")
response = DocumentRagResponse(
response="A document about cats.",
end_of_stream=True,
end_of_session=True,
error=None
)
@ -127,9 +128,31 @@ class TestRAGTranslatorCompletionFlags:
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when end_of_stream=True"
assert is_final is True, "is_final must be True when end_of_session=True"
assert response_dict["response"] == "A document about cats."
assert response_dict["end_of_session"] is True
def test_document_rag_translator_end_of_stream_not_final(self):
"""
Test that end_of_stream=True alone does NOT make is_final=True.
The session continues with provenance messages after LLM stream completes.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("document-rag")
response = DocumentRagResponse(
response="Final chunk",
end_of_stream=True,
end_of_session=False, # Session continues with provenance
error=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
assert response_dict["end_of_stream"] is True
assert response_dict["end_of_session"] is False
def test_document_rag_translator_is_final_with_end_of_stream_false(self):
"""

View file

@ -30,10 +30,13 @@ class TestAgentStructuredQueryIntegration:
pulsar_client=AsyncMock(),
max_iterations=3
)
# Mock the client method for structured query
proc.client = MagicMock()
# Mock librarian to avoid hanging on save operations
proc.save_answer_content = AsyncMock(return_value=None)
return proc
@pytest.fixture

View file

@ -28,6 +28,9 @@ class TestAgentServiceNonStreaming:
max_iterations=10
)
# Mock librarian to avoid hanging on save operations
processor.save_answer_content = AsyncMock(return_value=None)
# Track all responses sent
sent_responses = []
@ -106,6 +109,9 @@ class TestAgentServiceNonStreaming:
max_iterations=10
)
# Mock librarian to avoid hanging on save operations
processor.save_answer_content = AsyncMock(return_value=None)
# Track all responses sent
sent_responses = []
@ -173,6 +179,9 @@ class TestAgentServiceNonStreaming:
max_iterations=10
)
# Mock librarian to avoid hanging on save operations
processor.save_answer_content = AsyncMock(return_value=None)
# Track all responses sent
sent_responses = []

View file

@ -68,6 +68,7 @@ class TestDocumentRagService:
collection="test_coll_1", # Must be from message, not hardcoded default
doc_limit=5,
explain_callback=ANY, # Explainability callback is always passed
save_answer_callback=ANY, # Librarian save callback is always passed
)
# Verify response was sent

View file

@ -59,7 +59,7 @@ from .flow import Flow, FlowInstance
from .async_flow import AsyncFlow, AsyncFlowInstance
# WebSocket clients
from .socket_client import SocketClient, SocketFlowInstance
from .socket_client import SocketClient, SocketFlowInstance, build_term
from .async_socket_client import AsyncSocketClient, AsyncSocketFlowInstance
# Bulk operation clients
@ -70,6 +70,21 @@ from .async_bulk_client import AsyncBulkClient
from .metrics import Metrics
from .async_metrics import AsyncMetrics
# Explainability
from .explainability import (
ExplainabilityClient,
ExplainEntity,
Question,
Exploration,
Focus,
Synthesis,
Analysis,
Conclusion,
EdgeSelection,
wire_triples_to_tuples,
extract_term_value,
)
# Types
from .types import (
Triple,
@ -85,6 +100,7 @@ from .types import (
AgentObservation,
AgentAnswer,
RAGChunk,
ProvenanceEvent,
)
# Exceptions
@ -124,6 +140,7 @@ __all__ = [
"SocketFlowInstance",
"AsyncSocketClient",
"AsyncSocketFlowInstance",
"build_term",
# Bulk operation clients
"BulkClient",
@ -133,6 +150,19 @@ __all__ = [
"Metrics",
"AsyncMetrics",
# Explainability
"ExplainabilityClient",
"ExplainEntity",
"Question",
"Exploration",
"Focus",
"Synthesis",
"Analysis",
"Conclusion",
"EdgeSelection",
"wire_triples_to_tuples",
"extract_term_value",
# Types
"Triple",
"Uri",
@ -147,6 +177,7 @@ __all__ = [
"AgentObservation",
"AgentAnswer",
"RAGChunk",
"ProvenanceEvent",
# Exceptions
"ProtocolException",

File diff suppressed because it is too large Load diff

View file

@ -15,6 +15,63 @@ from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, Strea
from . exceptions import ProtocolException, raise_from_error_dict
def build_term(value: Any, term_type: Optional[str] = None,
datatype: Optional[str] = None, language: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""
Build wire-format Term dict from a value.
Auto-detection rules (when term_type is None):
- Already a dict with 't' key -> return as-is (already a Term)
- Starts with http://, https://, urn: -> IRI
- Wrapped in <> (e.g., <http://...>) -> IRI (angle brackets stripped)
- Anything else -> literal
Args:
value: The term value (string, dict, or None)
term_type: One of 'iri', 'literal', or None for auto-detect
datatype: Datatype for literal objects (e.g., xsd:integer)
language: Language tag for literal objects (e.g., en)
Returns:
dict: Wire-format Term dict, or None if value is None
"""
if value is None:
return None
# If already a Term dict, return as-is
if isinstance(value, dict) and "t" in value:
return value
# Convert to string for processing
value = str(value)
# Auto-detect type if not specified
if term_type is None:
if value.startswith("<") and value.endswith(">") and not value.startswith("<<"):
# Angle-bracket wrapped IRI: <http://...>
value = value[1:-1] # Strip < and >
term_type = "iri"
elif value.startswith(("http://", "https://", "urn:")):
term_type = "iri"
else:
term_type = "literal"
if term_type == "iri":
# Strip angle brackets if present
if value.startswith("<") and value.endswith(">"):
value = value[1:-1]
return {"t": "i", "i": value}
elif term_type == "literal":
result = {"t": "l", "v": value}
if datatype:
result["dt"] = datatype
if language:
result["ln"] = language
return result
else:
raise ValueError(f"Unknown term type: {term_type}")
class SocketClient:
"""
Synchronous WebSocket client for streaming operations.
@ -92,7 +149,8 @@ class SocketClient:
flow: Optional[str],
request: Dict[str, Any],
streaming: bool = False,
streaming_raw: bool = False
streaming_raw: bool = False,
include_provenance: bool = False
) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]:
"""Synchronous wrapper around async WebSocket communication.
@ -119,7 +177,7 @@ class SocketClient:
return self._streaming_generator_raw(service, flow, request, loop)
elif streaming:
# Parsed streaming for agent/RAG chunk types
return self._streaming_generator(service, flow, request, loop)
return self._streaming_generator(service, flow, request, loop, include_provenance)
else:
# Non-streaming single response
return loop.run_until_complete(self._send_request_async(service, flow, request))
@ -129,10 +187,11 @@ class SocketClient:
service: str,
flow: Optional[str],
request: Dict[str, Any],
loop: asyncio.AbstractEventLoop
loop: asyncio.AbstractEventLoop,
include_provenance: bool = False
) -> Iterator[StreamingChunk]:
"""Generator that yields streaming chunks (for agent/RAG responses)"""
async_gen = self._send_request_async_streaming(service, flow, request)
async_gen = self._send_request_async_streaming(service, flow, request, include_provenance)
try:
while True:
@ -265,7 +324,8 @@ class SocketClient:
self,
service: str,
flow: Optional[str],
request: Dict[str, Any]
request: Dict[str, Any],
include_provenance: bool = False
) -> Iterator[StreamingChunk]:
"""Async implementation of WebSocket request (streaming)"""
# Generate unique request ID
@ -309,8 +369,8 @@ class SocketClient:
raise_from_error_dict(resp["error"])
# Parse different chunk types
chunk = self._parse_chunk(resp)
if chunk is not None: # Skip provenance messages in streaming
chunk = self._parse_chunk(resp, include_provenance=include_provenance)
if chunk is not None: # Skip provenance messages unless include_provenance
yield chunk
# Check if this is the final message
@ -325,14 +385,26 @@ class SocketClient:
chunk_type = resp.get("chunk_type")
message_type = resp.get("message_type")
# Handle new GraphRAG message format with message_type
if message_type == "provenance":
# Handle GraphRAG/DocRAG message format with message_type
if message_type == "explain":
if include_provenance:
# Return provenance event for explainability
return ProvenanceEvent(provenance_id=resp.get("provenance_id", ""))
return ProvenanceEvent(
explain_id=resp.get("explain_id", ""),
explain_graph=resp.get("explain_graph", "")
)
# Provenance messages are not yielded to user - they're metadata
return None
# Handle Agent message format with chunk_type="explain"
if chunk_type == "explain":
if include_provenance:
return ProvenanceEvent(
explain_id=resp.get("explain_id", ""),
explain_graph=resp.get("explain_graph", "")
)
return None
if chunk_type == "thought":
return AgentThought(
content=resp.get("content", ""),
@ -477,6 +549,95 @@ class SocketFlowInstance:
# regardless of streaming flag, so always use the streaming code path
return self.client._send_request_sync("agent", self.flow_id, request, streaming=True)
def agent_explain(
self,
question: str,
user: str,
collection: str,
state: Optional[Dict[str, Any]] = None,
group: Optional[str] = None,
history: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any
) -> Iterator[Union[StreamingChunk, ProvenanceEvent]]:
"""
Execute an agent operation with explainability support.
Streams both content chunks (AgentThought, AgentObservation, AgentAnswer)
and provenance events (ProvenanceEvent). Provenance events contain URIs
that can be fetched using ExplainabilityClient to get detailed information
about the agent's reasoning process.
Agent trace consists of:
- Session: The initial question and session metadata
- Iterations: Each thought/action/observation cycle
- Conclusion: The final answer
Args:
question: User question or instruction
user: User identifier
collection: Collection identifier for provenance storage
state: Optional state dictionary for stateful conversations
group: Optional group identifier for multi-user contexts
history: Optional conversation history as list of message dicts
**kwargs: Additional parameters passed to the agent service
Yields:
Union[StreamingChunk, ProvenanceEvent]: Agent chunks and provenance events
Example:
```python
from trustgraph.api import Api, ExplainabilityClient, ProvenanceEvent
from trustgraph.api import AgentThought, AgentObservation, AgentAnswer
socket = api.socket()
flow = socket.flow("default")
explain_client = ExplainabilityClient(flow)
provenance_ids = []
for item in flow.agent_explain(
question="What is the capital of France?",
user="trustgraph",
collection="default"
):
if isinstance(item, AgentThought):
print(f"[Thought] {item.content}")
elif isinstance(item, AgentObservation):
print(f"[Observation] {item.content}")
elif isinstance(item, AgentAnswer):
print(f"[Answer] {item.content}")
elif isinstance(item, ProvenanceEvent):
provenance_ids.append(item.explain_id)
# Fetch session trace after completion
if provenance_ids:
trace = explain_client.fetch_agent_trace(
provenance_ids[0], # Session URI is first
graph="urn:graph:retrieval",
user="trustgraph",
collection="default"
)
```
"""
request = {
"question": question,
"user": user,
"collection": collection,
"streaming": True # Always streaming for explain
}
if state is not None:
request["state"] = state
if group is not None:
request["group"] = group
if history is not None:
request["history"] = history
request.update(kwargs)
# Use streaming with provenance enabled
return self.client._send_request_sync(
"agent", self.flow_id, request,
streaming=True, include_provenance=True
)
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
"""
Execute text completion with optional streaming.
@ -596,6 +757,86 @@ class SocketFlowInstance:
else:
return result.get("response", "")
def graph_rag_explain(
self,
query: str,
user: str,
collection: str,
max_subgraph_size: int = 1000,
max_subgraph_count: int = 5,
max_entity_distance: int = 3,
**kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""
Execute graph-based RAG query with explainability support.
Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent).
Provenance events contain URIs that can be fetched using ExplainabilityClient
to get detailed information about how the response was generated.
Args:
query: Natural language query
user: User/keyspace identifier
collection: Collection identifier
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
max_subgraph_count: Maximum number of subgraphs (default: 5)
max_entity_distance: Maximum traversal depth (default: 3)
**kwargs: Additional parameters passed to the service
Yields:
Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events
Example:
```python
from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent
socket = api.socket()
flow = socket.flow("default")
explain_client = ExplainabilityClient(flow)
provenance_ids = []
response_text = ""
for item in flow.graph_rag_explain(
query="Tell me about Marie Curie",
user="trustgraph",
collection="scientists"
):
if isinstance(item, RAGChunk):
response_text += item.content
print(item.content, end='', flush=True)
elif isinstance(item, ProvenanceEvent):
provenance_ids.append(item.provenance_id)
# Fetch explainability details
for prov_id in provenance_ids:
entity = explain_client.fetch_entity(
prov_id,
graph="urn:graph:retrieval",
user="trustgraph",
collection="scientists"
)
print(f"Entity: {entity}")
```
"""
request = {
"query": query,
"user": user,
"collection": collection,
"max-subgraph-size": max_subgraph_size,
"max-subgraph-count": max_subgraph_count,
"max-entity-distance": max_entity_distance,
"streaming": True,
"explainable": True, # Enable explainability mode
}
request.update(kwargs)
# Use streaming with provenance events included
return self.client._send_request_sync(
"graph-rag", self.flow_id, request,
streaming=True, include_provenance=True
)
def document_rag(
self,
query: str,
@ -654,6 +895,79 @@ class SocketFlowInstance:
else:
return result.get("response", "")
def document_rag_explain(
self,
query: str,
user: str,
collection: str,
doc_limit: int = 10,
**kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""
Execute document-based RAG query with explainability support.
Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent).
Provenance events contain URIs that can be fetched using ExplainabilityClient
to get detailed information about how the response was generated.
Document RAG trace consists of:
- Question: The user's query
- Exploration: Chunks retrieved from document store (chunk_count)
- Synthesis: The generated answer
Args:
query: Natural language query
user: User/keyspace identifier
collection: Collection identifier
doc_limit: Maximum document chunks to retrieve (default: 10)
**kwargs: Additional parameters passed to the service
Yields:
Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events
Example:
```python
from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent
socket = api.socket()
flow = socket.flow("default")
explain_client = ExplainabilityClient(flow)
for item in flow.document_rag_explain(
query="Summarize the key findings",
user="trustgraph",
collection="research-papers",
doc_limit=5
):
if isinstance(item, RAGChunk):
print(item.content, end='', flush=True)
elif isinstance(item, ProvenanceEvent):
# Fetch entity details
entity = explain_client.fetch_entity(
item.explain_id,
graph=item.explain_graph,
user="trustgraph",
collection="research-papers"
)
print(f"Event: {entity}", file=sys.stderr)
```
"""
request = {
"query": query,
"user": user,
"collection": collection,
"doc-limit": doc_limit,
"streaming": True,
"explainable": True,
}
request.update(kwargs)
# Use streaming with provenance events included
return self.client._send_request_sync(
"document-rag", self.flow_id, request,
streaming=True, include_provenance=True
)
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
"""Generator for RAG streaming (graph-rag and document-rag)"""
for chunk in result:
@ -831,28 +1145,30 @@ class SocketFlowInstance:
def triples_query(
self,
s: Optional[str] = None,
p: Optional[str] = None,
o: Optional[str] = None,
s: Optional[Union[str, Dict[str, Any]]] = None,
p: Optional[Union[str, Dict[str, Any]]] = None,
o: Optional[Union[str, Dict[str, Any]]] = None,
g: Optional[str] = None,
user: Optional[str] = None,
collection: Optional[str] = None,
limit: int = 100,
**kwargs: Any
) -> Dict[str, Any]:
) -> List[Dict[str, Any]]:
"""
Query knowledge graph triples using pattern matching.
Args:
s: Subject URI (optional, use None for wildcard)
p: Predicate URI (optional, use None for wildcard)
o: Object URI or Literal (optional, use None for wildcard)
s: Subject filter - URI string, Term dict, or None for wildcard
p: Predicate filter - URI string, Term dict, or None for wildcard
o: Object filter - URI/literal string, Term dict, or None for wildcard
g: Named graph filter - URI string or None for all graphs
user: User/keyspace identifier (optional)
collection: Collection identifier (optional)
limit: Maximum results to return (default: 100)
**kwargs: Additional parameters passed to the service
Returns:
dict: Query results with matching triples
List[Dict]: List of matching triples in wire format
Example:
```python
@ -860,33 +1176,54 @@ class SocketFlowInstance:
flow = socket.flow("default")
# Find all triples about a specific subject
result = flow.triples_query(
triples = flow.triples_query(
s="http://example.org/person/marie-curie",
user="trustgraph",
collection="scientists"
)
# Query with named graph filter
triples = flow.triples_query(
s="urn:trustgraph:session:abc123",
g="urn:graph:retrieval",
user="trustgraph",
collection="default"
)
```
"""
request = {"limit": limit}
if s is not None:
request["s"] = str(s)
if p is not None:
request["p"] = str(p)
if o is not None:
request["o"] = str(o)
# Build Term dicts for s/p/o (auto-converts strings)
s_term = build_term(s)
p_term = build_term(p)
o_term = build_term(o)
if s_term is not None:
request["s"] = s_term
if p_term is not None:
request["p"] = p_term
if o_term is not None:
request["o"] = o_term
if g is not None:
request["g"] = g
if user is not None:
request["user"] = user
if collection is not None:
request["collection"] = collection
request.update(kwargs)
return self.client._send_request_sync("triples", self.flow_id, request, False)
result = self.client._send_request_sync("triples", self.flow_id, request, False)
# Return the triples list from the response
if isinstance(result, dict) and "response" in result:
return result["response"]
return result
def triples_query_stream(
self,
s: Optional[str] = None,
p: Optional[str] = None,
o: Optional[str] = None,
s: Optional[Union[str, Dict[str, Any]]] = None,
p: Optional[Union[str, Dict[str, Any]]] = None,
o: Optional[Union[str, Dict[str, Any]]] = None,
g: Optional[str] = None,
user: Optional[str] = None,
collection: Optional[str] = None,
limit: int = 100,
@ -900,9 +1237,10 @@ class SocketFlowInstance:
and memory overhead for large result sets.
Args:
s: Subject URI (optional, use None for wildcard)
p: Predicate URI (optional, use None for wildcard)
o: Object URI or Literal (optional, use None for wildcard)
s: Subject filter - URI string, Term dict, or None for wildcard
p: Predicate filter - URI string, Term dict, or None for wildcard
o: Object filter - URI/literal string, Term dict, or None for wildcard
g: Named graph filter - URI string or None for all graphs
user: User/keyspace identifier (optional)
collection: Collection identifier (optional)
limit: Maximum results to return (default: 100)
@ -930,12 +1268,20 @@ class SocketFlowInstance:
"streaming": True,
"batch-size": batch_size,
}
if s is not None:
request["s"] = str(s)
if p is not None:
request["p"] = str(p)
if o is not None:
request["o"] = str(o)
# Build Term dicts for s/p/o (auto-converts strings)
s_term = build_term(s)
p_term = build_term(p)
o_term = build_term(o)
if s_term is not None:
request["s"] = s_term
if p_term is not None:
request["p"] = p_term
if o_term is not None:
request["o"] = o_term
if g is not None:
request["g"] = g
if user is not None:
request["user"] = user
if collection is not None:

View file

@ -212,19 +212,21 @@ class ProvenanceEvent:
Each event represents a provenance node created during query processing.
Attributes:
provenance_id: URI of the provenance node (e.g., urn:trustgraph:session:abc123)
event_type: Type of provenance event (session, retrieval, selection, answer)
explain_id: URI of the provenance node (e.g., urn:trustgraph:question:abc123)
explain_graph: Named graph where provenance triples are stored (e.g., urn:graph:retrieval)
event_type: Type of provenance event (question, exploration, focus, synthesis)
"""
provenance_id: str
event_type: str = "" # Derived from provenance_id (session, retrieval, selection, answer)
explain_id: str
explain_graph: str = ""
event_type: str = "" # Derived from explain_id
def __post_init__(self):
# Extract event type from provenance_id
if "session" in self.provenance_id:
self.event_type = "session"
elif "retrieval" in self.provenance_id:
self.event_type = "retrieval"
elif "selection" in self.provenance_id:
self.event_type = "selection"
elif "answer" in self.provenance_id:
self.event_type = "answer"
# Extract event type from explain_id
if "question" in self.explain_id:
self.event_type = "question"
elif "exploration" in self.explain_id:
self.event_type = "exploration"
elif "focus" in self.explain_id:
self.event_type = "focus"
elif "synthesis" in self.explain_id:
self.event_type = "synthesis"

View file

@ -59,6 +59,15 @@ class AgentResponseTranslator(MessageTranslator):
result["end_of_message"] = getattr(obj, "end_of_message", False)
result["end_of_dialog"] = getattr(obj, "end_of_dialog", False)
# Include explainability fields if present
explain_id = getattr(obj, "explain_id", None)
if explain_id:
result["explain_id"] = explain_id
explain_graph = getattr(obj, "explain_graph", None)
if explain_graph is not None:
result["explain_graph"] = explain_graph
# Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "code": obj.error.code}

View file

@ -34,7 +34,12 @@ class DocumentRagResponseTranslator(MessageTranslator):
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
result = {}
# Include response content (even if empty string)
# Include message_type for distinguishing chunk vs explain messages
message_type = getattr(obj, "message_type", "")
if message_type:
result["message_type"] = message_type
# Include response content for chunk messages
if obj.response is not None:
result["response"] = obj.response
@ -48,9 +53,12 @@ class DocumentRagResponseTranslator(MessageTranslator):
if explain_graph is not None:
result["explain_graph"] = explain_graph
# Include end_of_stream flag
# Include end_of_stream flag (LLM stream complete)
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
# Include end_of_session flag (entire session complete)
result["end_of_session"] = getattr(obj, "end_of_session", False)
# Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
@ -59,7 +67,8 @@ class DocumentRagResponseTranslator(MessageTranslator):
def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
is_final = getattr(obj, 'end_of_stream', False)
# Session is complete when end_of_session is True
is_final = getattr(obj, 'end_of_session', False)
return self.from_pulsar(obj), is_final

View file

@ -82,6 +82,10 @@ from . namespaces import (
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
# Agent provenance predicates
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
# Agent document references
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
# Document reference predicate
TG_DOCUMENT,
# Named graphs
GRAPH_DEFAULT, GRAPH_SOURCE, GRAPH_RETRIEVAL,
)
@ -165,6 +169,10 @@ __all__ = [
"TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION",
# Agent provenance predicates
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_ANSWER",
# Agent document references
"TG_THOUGHT_DOCUMENT", "TG_OBSERVATION_DOCUMENT",
# Document reference predicate
"TG_DOCUMENT",
# Named graphs
"GRAPH_DEFAULT", "GRAPH_SOURCE", "GRAPH_RETRIEVAL",
# Triple builders

View file

@ -17,7 +17,8 @@ from . namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
TG_AGENT_QUESTION,
)
@ -73,10 +74,12 @@ def agent_session_triples(
def agent_iteration_triples(
iteration_uri: str,
parent_uri: str,
thought: str,
action: str,
arguments: Dict[str, Any],
observation: str,
thought: str = "",
action: str = "",
arguments: Dict[str, Any] = None,
observation: str = "",
thought_document_id: Optional[str] = None,
observation_document_id: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for one agent iteration (Analysis - think/act/observe cycle).
@ -85,36 +88,53 @@ def agent_iteration_triples(
- Entity declaration with tg:Analysis type
- wasDerivedFrom link to parent (previous iteration or session)
- Thought, action, arguments, and observation data
- Document references for thought/observation when stored in librarian
Args:
iteration_uri: URI of this iteration (from agent_iteration_uri)
parent_uri: URI of the parent (previous iteration or session)
thought: The agent's reasoning/thought
thought: The agent's reasoning/thought (used if thought_document_id not provided)
action: The tool/action name
arguments: Arguments passed to the tool (will be JSON-encoded)
observation: The result/observation from the tool
observation: The result/observation from the tool (used if observation_document_id not provided)
thought_document_id: Optional document URI for thought in librarian (preferred)
observation_document_id: Optional document URI for observation in librarian (preferred)
Returns:
List of Triple objects
"""
if arguments is None:
arguments = {}
triples = [
_triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)),
_triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")),
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
_triple(iteration_uri, TG_THOUGHT, _literal(thought)),
_triple(iteration_uri, TG_ACTION, _literal(action)),
_triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))),
_triple(iteration_uri, TG_OBSERVATION, _literal(observation)),
]
# Thought: use document reference or inline
if thought_document_id:
triples.append(_triple(iteration_uri, TG_THOUGHT_DOCUMENT, _iri(thought_document_id)))
elif thought:
triples.append(_triple(iteration_uri, TG_THOUGHT, _literal(thought)))
# Observation: use document reference or inline
if observation_document_id:
triples.append(_triple(iteration_uri, TG_OBSERVATION_DOCUMENT, _iri(observation_document_id)))
elif observation:
triples.append(_triple(iteration_uri, TG_OBSERVATION, _literal(observation)))
return triples
def agent_final_triples(
final_uri: str,
parent_uri: str,
answer: str,
answer: str = "",
document_id: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for an agent final answer (Conclusion).
@ -122,20 +142,29 @@ def agent_final_triples(
Creates:
- Entity declaration with tg:Conclusion type
- wasDerivedFrom link to parent (last iteration or session)
- The answer text
- Either document reference (if document_id provided) or inline answer
Args:
final_uri: URI of the final answer (from agent_final_uri)
parent_uri: URI of the parent (last iteration or session if no iterations)
answer: The final answer text
answer: The final answer text (used if document_id not provided)
document_id: Optional document URI in librarian (preferred)
Returns:
List of Triple objects
"""
return [
triples = [
_triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)),
_triple(final_uri, RDFS_LABEL, _literal("Conclusion")),
_triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
_triple(final_uri, TG_ANSWER, _literal(answer)),
]
if document_id:
# Store reference to document in librarian (as IRI)
triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id)))
elif answer:
# Fallback: store inline answer
triples.append(_triple(final_uri, TG_ANSWER, _literal(answer)))
return triples

View file

@ -92,6 +92,10 @@ TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation"
TG_ANSWER = TG + "answer"
# Agent document references (for librarian storage)
TG_THOUGHT_DOCUMENT = TG + "thoughtDocument"
TG_OBSERVATION_DOCUMENT = TG + "observationDocument"
# Named graph URIs for RDF datasets
# These separate different types of data while keeping them in the same collection
GRAPH_DEFAULT = "" # Core knowledge facts (triples extracted from documents)

View file

@ -30,11 +30,15 @@ class AgentRequest:
@dataclass
class AgentResponse:
# Streaming-first design
chunk_type: str = "" # "thought", "action", "observation", "answer", "error"
chunk_type: str = "" # "thought", "action", "observation", "answer", "explain", "error"
content: str = "" # The actual content (interpretation depends on chunk_type)
end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete
end_of_dialog: bool = False # Entire agent dialog is complete
# Explainability fields
explain_id: str | None = None # Provenance URI (announced as created)
explain_graph: str | None = None # Named graph where explain was stored
# Legacy fields (deprecated but kept for backward compatibility)
answer: str = ""
error: Error | None = None

View file

@ -43,6 +43,8 @@ class DocumentRagQuery:
class DocumentRagResponse:
error: Error | None = None
response: str | None = ""
end_of_stream: bool = False
end_of_stream: bool = False # LLM response stream complete
explain_id: str | None = None # Single explain URI (announced as created)
explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval)
message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete

View file

@ -4,8 +4,19 @@ Uses the agent service to answer a question
import argparse
import os
import sys
import textwrap
from trustgraph.api import Api
from trustgraph.api import (
Api,
ExplainabilityClient,
ProvenanceEvent,
Question,
Analysis,
Conclusion,
AgentThought,
AgentObservation,
AgentAnswer,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -97,11 +108,148 @@ def output(text, prefix="> ", width=78):
)
print(out)
def question_explainable(
url, question_text, flow_id, user, collection,
state=None, group=None, verbose=False, token=None, debug=False
):
"""Execute agent with explainability - shows provenance events inline."""
api = Api(url=url, token=token)
socket = api.socket()
flow = socket.flow(flow_id)
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
try:
# Track last chunk type for formatting
last_chunk_type = None
current_outputter = None
# Stream agent with explainability - process events as they arrive
for item in flow.agent_explain(
question=question_text,
user=user,
collection=collection,
state=state,
group=group,
):
if isinstance(item, AgentThought):
if last_chunk_type != "thought":
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
print() # Blank line between message types
if verbose:
current_outputter = Outputter(width=78, prefix="\U0001f914 ")
current_outputter.__enter__()
last_chunk_type = "thought"
if current_outputter:
current_outputter.output(item.content)
if current_outputter.word_buffer:
print(current_outputter.word_buffer, end="", flush=True)
current_outputter.column += len(current_outputter.word_buffer)
current_outputter.word_buffer = ""
elif isinstance(item, AgentObservation):
if last_chunk_type != "observation":
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
print()
if verbose:
current_outputter = Outputter(width=78, prefix="\U0001f4a1 ")
current_outputter.__enter__()
last_chunk_type = "observation"
if current_outputter:
current_outputter.output(item.content)
if current_outputter.word_buffer:
print(current_outputter.word_buffer, end="", flush=True)
current_outputter.column += len(current_outputter.word_buffer)
current_outputter.word_buffer = ""
elif isinstance(item, AgentAnswer):
if last_chunk_type != "answer":
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
print()
last_chunk_type = "answer"
# Print answer content directly
print(item.content, end="", flush=True)
elif isinstance(item, ProvenanceEvent):
# Process provenance event immediately
prov_id = item.explain_id
explain_graph = item.explain_graph or "urn:graph:retrieval"
entity = explain_client.fetch_entity(
prov_id,
graph=explain_graph,
user=user,
collection=collection
)
if entity is None:
if debug:
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
continue
# Display based on entity type
if isinstance(entity, Question):
print(f"\n [session] {prov_id}", file=sys.stderr)
if entity.query:
print(f" Query: {entity.query}", file=sys.stderr)
if entity.timestamp:
print(f" Time: {entity.timestamp}", file=sys.stderr)
elif isinstance(entity, Analysis):
print(f"\n [iteration] {prov_id}", file=sys.stderr)
if entity.thought:
thought_short = entity.thought[:80] + "..." if len(entity.thought) > 80 else entity.thought
print(f" Thought: {thought_short}", file=sys.stderr)
if entity.action:
print(f" Action: {entity.action}", file=sys.stderr)
elif isinstance(entity, Conclusion):
print(f"\n [conclusion] {prov_id}", file=sys.stderr)
if entity.answer:
print(f" Answer length: {len(entity.answer)} chars", file=sys.stderr)
else:
if debug:
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
# Close any remaining outputter
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
# Final newline if we ended with answer
if last_chunk_type == "answer":
print()
finally:
socket.close()
def question(
url, question, flow_id, user, collection,
plan=None, state=None, group=None, verbose=False, streaming=True,
token=None
token=None, explainable=False, debug=False
):
# Explainable mode uses the API to capture and process provenance events
if explainable:
question_explainable(
url=url,
question_text=question,
flow_id=flow_id,
user=user,
collection=collection,
state=state,
group=group,
verbose=verbose,
token=token,
debug=debug
)
return
if verbose:
output(wrap(question), "\U00002753 ")
@ -270,6 +418,18 @@ def main():
help=f'Disable streaming (use legacy mode)'
)
parser.add_argument(
'-x', '--explainable',
action='store_true',
help='Show provenance events: Session, Iterations, Conclusion (implies streaming)'
)
parser.add_argument(
'--debug',
action='store_true',
help='Show debug output for troubleshooting'
)
args = parser.parse_args()
try:
@ -286,6 +446,8 @@ def main():
verbose = args.verbose,
streaming = not args.no_streaming,
token = args.token,
explainable = args.explainable,
debug = args.debug,
)
except Exception as e:

View file

@ -4,7 +4,16 @@ Uses the DocumentRAG service to answer a question
import argparse
import os
from trustgraph.api import Api
import sys
from trustgraph.api import (
Api,
ExplainabilityClient,
RAGChunk,
ProvenanceEvent,
Question,
Exploration,
Synthesis,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -12,7 +21,90 @@ default_user = 'trustgraph'
default_collection = 'default'
default_doc_limit = 10
def question(url, flow_id, question, user, collection, doc_limit, streaming=True, token=None):
def question_explainable(
url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False
):
"""Execute document RAG with explainability - shows provenance events inline."""
api = Api(url=url, token=token)
socket = api.socket()
flow = socket.flow(flow_id)
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
try:
# Stream DocumentRAG with explainability - process events as they arrive
for item in flow.document_rag_explain(
query=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
):
if isinstance(item, RAGChunk):
# Print response content
print(item.content, end="", flush=True)
elif isinstance(item, ProvenanceEvent):
# Process provenance event immediately
prov_id = item.explain_id
explain_graph = item.explain_graph or "urn:graph:retrieval"
entity = explain_client.fetch_entity(
prov_id,
graph=explain_graph,
user=user,
collection=collection
)
if entity is None:
if debug:
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
continue
# Display based on entity type
if isinstance(entity, Question):
print(f"\n [question] {prov_id}", file=sys.stderr)
if entity.query:
print(f" Query: {entity.query}", file=sys.stderr)
if entity.timestamp:
print(f" Time: {entity.timestamp}", file=sys.stderr)
elif isinstance(entity, Exploration):
print(f"\n [exploration] {prov_id}", file=sys.stderr)
if entity.chunk_count:
print(f" Chunks retrieved: {entity.chunk_count}", file=sys.stderr)
elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
if entity.content:
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
else:
if debug:
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
print() # Final newline
finally:
socket.close()
def question(
url, flow_id, question_text, user, collection, doc_limit,
streaming=True, token=None, explainable=False, debug=False
):
# Explainable mode uses the API to capture and process provenance events
if explainable:
question_explainable(
url=url,
flow_id=flow_id,
question_text=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
token=token,
debug=debug
)
return
# Create API client
api = Api(url=url, token=token)
@ -24,7 +116,7 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True
try:
response = flow.document_rag(
query=question,
query=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
@ -42,13 +134,14 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True
# Use REST API for non-streaming
flow = api.flow().id(flow_id)
resp = flow.document_rag(
query=question,
query=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
)
print(resp)
def main():
parser = argparse.ArgumentParser(
@ -105,6 +198,18 @@ def main():
help='Disable streaming (use non-streaming mode)'
)
parser.add_argument(
'-x', '--explainable',
action='store_true',
help='Show provenance events: Question, Exploration, Synthesis (implies streaming)'
)
parser.add_argument(
'--debug',
action='store_true',
help='Show debug output for troubleshooting'
)
args = parser.parse_args()
try:
@ -112,12 +217,14 @@ def main():
question(
url=args.url,
flow_id=args.flow_id,
question=args.question,
question_text=args.question,
user=args.user,
collection=args.collection,
doc_limit=args.doc_limit,
streaming=not args.no_streaming,
token=args.token,
explainable=args.explainable,
debug=args.debug,
)
except Exception as e:

View file

@ -8,7 +8,16 @@ import os
import sys
import websockets
import asyncio
from trustgraph.api import Api
from trustgraph.api import (
Api,
ExplainabilityClient,
RAGChunk,
ProvenanceEvent,
Question,
Exploration,
Focus,
Synthesis,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -602,18 +611,111 @@ async def _question_explainable(
print() # Final newline
def _question_explainable_api(
url, flow_id, question_text, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, token=None, debug=False
):
"""Execute graph RAG with explainability using the new API classes."""
api = Api(url=url, token=token)
socket = api.socket()
flow = socket.flow(flow_id)
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
try:
# Stream GraphRAG with explainability - process events as they arrive
for item in flow.graph_rag_explain(
query=question_text,
user=user,
collection=collection,
max_subgraph_size=max_subgraph_size,
max_subgraph_count=5,
max_entity_distance=max_path_length,
):
if isinstance(item, RAGChunk):
# Print response content
print(item.content, end="", flush=True)
elif isinstance(item, ProvenanceEvent):
# Process provenance event immediately
prov_id = item.explain_id
explain_graph = item.explain_graph or "urn:graph:retrieval"
entity = explain_client.fetch_entity(
prov_id,
graph=explain_graph,
user=user,
collection=collection
)
if entity is None:
if debug:
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
continue
# Display based on entity type
if isinstance(entity, Question):
print(f"\n [question] {prov_id}", file=sys.stderr)
if entity.query:
print(f" Query: {entity.query}", file=sys.stderr)
if entity.timestamp:
print(f" Time: {entity.timestamp}", file=sys.stderr)
elif isinstance(entity, Exploration):
print(f"\n [exploration] {prov_id}", file=sys.stderr)
if entity.edge_count:
print(f" Edges explored: {entity.edge_count}", file=sys.stderr)
elif isinstance(entity, Focus):
print(f"\n [focus] {prov_id}", file=sys.stderr)
if entity.selected_edge_uris:
print(f" Focused on {len(entity.selected_edge_uris)} edge(s)", file=sys.stderr)
# Fetch full focus with edge details
focus_full = explain_client.fetch_focus_with_edges(
prov_id,
graph=explain_graph,
user=user,
collection=collection
)
if focus_full and focus_full.edge_selections:
for edge_sel in focus_full.edge_selections:
if edge_sel.edge:
# Resolve labels for edge components
s_label, p_label, o_label = explain_client.resolve_edge_labels(
edge_sel.edge, user, collection
)
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
if edge_sel.reasoning:
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reason: {r_short}", file=sys.stderr)
elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
if entity.content:
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
else:
if debug:
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
print() # Final newline
finally:
socket.close()
def question(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, streaming=True, token=None,
explainable=False, debug=False
):
# Explainable mode uses direct websocket to capture provenance events
# Explainable mode uses the API to capture and process provenance events
if explainable:
asyncio.run(_question_explainable(
_question_explainable_api(
url=url,
flow_id=flow_id,
question=question,
question_text=question,
user=user,
collection=collection,
entity_limit=entity_limit,
@ -622,7 +724,7 @@ def question(
max_path_length=max_path_length,
token=token,
debug=debug
))
)
return
# Create API client

View file

@ -14,180 +14,17 @@ import json
import os
import sys
from tabulate import tabulate
from trustgraph.api import Api
from trustgraph.api import Api, ExplainabilityClient
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_user = 'trustgraph'
default_collection = 'default'
# Predicates
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_QUESTION = TG + "Question"
TG_ANALYSIS = TG + "Analysis"
TG_EXPLORATION = TG + "Exploration"
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
# Retrieval graph
RETRIEVAL_GRAPH = "urn:graph:retrieval"
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
"""Query triples using the socket API."""
request = {
"user": user,
"collection": collection,
"limit": limit,
"streaming": False,
}
if s is not None:
request["s"] = {"t": "i", "i": s}
if p is not None:
request["p"] = {"t": "i", "i": p}
if o is not None:
if isinstance(o, str):
if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"):
request["o"] = {"t": "i", "i": o}
else:
request["o"] = {"t": "l", "v": o}
elif isinstance(o, dict):
request["o"] = o
if g is not None:
request["g"] = g
triples = []
try:
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
if isinstance(response, dict):
triple_list = response.get("response", response.get("triples", []))
else:
triple_list = response
if not isinstance(triple_list, list):
triple_list = [triple_list] if triple_list else []
for t in triple_list:
s_val = extract_value(t.get("s", {}))
p_val = extract_value(t.get("p", {}))
o_val = extract_value(t.get("o", {}))
triples.append((s_val, p_val, o_val))
except Exception as e:
print(f"Error querying triples: {e}", file=sys.stderr)
return triples
def extract_value(term):
"""Extract value from a term dict."""
if not term:
return ""
t = term.get("t") or term.get("type")
if t == "i":
return term.get("i") or term.get("iri", "")
elif t == "l":
return term.get("v") or term.get("value", "")
elif t == "t":
# Quoted triple
tr = term.get("tr") or term.get("triple", {})
return {
"s": extract_value(tr.get("s", {})),
"p": extract_value(tr.get("p", {})),
"o": extract_value(tr.get("o", {})),
}
# Fallback for raw values
if "i" in term:
return term["i"]
if "v" in term:
return term["v"]
return str(term)
def get_timestamp(socket, flow_id, user, collection, question_id):
"""Get timestamp for a question."""
triples = query_triples(
socket, flow_id, user, collection,
s=question_id, p=PROV_STARTED_AT_TIME, g=RETRIEVAL_GRAPH
)
for s, p, o in triples:
return o
return ""
def get_session_type(socket, flow_id, user, collection, session_id):
"""
Get the type of session (Agent or GraphRAG).
Both have tg:Question type, so we distinguish by URI pattern
or by checking what's derived from it.
"""
# Fast path: check URI pattern
if session_id.startswith("urn:trustgraph:agent:"):
return "Agent"
if session_id.startswith("urn:trustgraph:question:"):
return "GraphRAG"
# Check what's derived from this entity
derived = query_triples(
socket, flow_id, user, collection,
p=PROV_WAS_DERIVED_FROM, o=session_id, g=RETRIEVAL_GRAPH
)
generated = query_triples(
socket, flow_id, user, collection,
p=PROV_WAS_GENERATED_BY, o=session_id, g=RETRIEVAL_GRAPH
)
for s, p, o in derived + generated:
child_types = query_triples(
socket, flow_id, user, collection,
s=s, p=RDF_TYPE, g=RETRIEVAL_GRAPH
)
for _, _, child_type in child_types:
if child_type == TG_ANALYSIS:
return "Agent"
if child_type == TG_EXPLORATION:
return "GraphRAG"
return "GraphRAG"
def list_sessions(socket, flow_id, user, collection, limit):
"""List all explainability sessions (GraphRAG and Agent) by finding questions."""
# Query for all triples with predicate = tg:query
triples = query_triples(
socket, flow_id, user, collection,
p=TG_QUERY, g=RETRIEVAL_GRAPH, limit=limit
)
sessions = []
for question_id, _, query_text in triples:
# Get timestamp if available
timestamp = get_timestamp(socket, flow_id, user, collection, question_id)
# Get session type (Agent or GraphRAG)
session_type = get_session_type(socket, flow_id, user, collection, question_id)
sessions.append({
"id": question_id,
"type": session_type,
"question": query_text,
"time": timestamp,
})
# Sort by timestamp (newest first) if available
sessions.sort(key=lambda x: x.get("time", ""), reverse=True)
return sessions
def truncate_text(text, max_len=60):
"""Truncate text to max length with ellipsis."""
if not text:
@ -277,16 +114,42 @@ def main():
try:
api = Api(args.api_url, token=args.token)
socket = api.socket()
flow = socket.flow(args.flow_id)
explain_client = ExplainabilityClient(flow)
try:
sessions = list_sessions(
socket=socket,
flow_id=args.flow_id,
# List all sessions using the API
questions = explain_client.list_sessions(
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
limit=args.limit,
)
# Convert to output format
sessions = []
for q in questions:
session_type = explain_client.detect_session_type(
q.uri,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection
)
# Map type names
type_display = {
"graphrag": "GraphRAG",
"docrag": "DocRAG",
"agent": "Agent",
}.get(session_type, session_type.title())
sessions.append({
"id": q.uri,
"type": type_display,
"question": q.query,
"time": q.timestamp,
})
if args.format == 'json':
print_json(sessions)
else:

View file

@ -291,42 +291,25 @@ def query_graph(
):
"""Query the triple store with pattern matching.
Uses the WebSocket API's raw streaming mode for efficient delivery of results.
Uses the API's triples_query_stream for efficient streaming delivery.
"""
socket = Api(url, token=token).socket()
# Build request dict directly (bypassing triples_query_stream's string conversion)
request = {
"user": user,
"collection": collection,
"limit": limit,
"streaming": True,
"batch-size": batch_size,
}
# Add term dicts for s/p/o (None means wildcard)
if subject is not None:
request["s"] = subject
if predicate is not None:
request["p"] = predicate
if obj is not None:
request["o"] = obj
if graph is not None:
request["g"] = graph
flow = socket.flow(flow_id)
all_triples = []
try:
# Use raw streaming mode - yields response dicts directly
for response in socket._send_request_sync(
"triples", flow_id, request, streaming_raw=True
# Use triples_query_stream - accepts Term dicts directly
for triples in flow.triples_query_stream(
s=subject,
p=predicate,
o=obj,
g=graph,
user=user,
collection=collection,
limit=limit,
batch_size=batch_size,
):
# Response may have triples in different locations depending on format
if isinstance(response, dict):
triples = response.get("response", response.get("triples", []))
else:
triples = response
if not isinstance(triples, list):
triples = [triples] if triples else []

View file

@ -18,228 +18,99 @@ import argparse
import json
import os
import sys
from trustgraph.api import Api
from trustgraph.api import (
Api,
ExplainabilityClient,
Question,
Exploration,
Focus,
Synthesis,
Analysis,
Conclusion,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_user = 'trustgraph'
default_collection = 'default'
# Predicates
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_DOCUMENT = TG + "document"
TG_REIFIES = TG + "reifies"
# Explainability entity types
TG_QUESTION = TG + "Question"
TG_EXPLORATION = TG + "Exploration"
TG_FOCUS = TG + "Focus"
TG_SYNTHESIS = TG + "Synthesis"
TG_ANALYSIS = TG + "Analysis"
TG_CONCLUSION = TG + "Conclusion"
# Agent predicates
TG_THOUGHT = TG + "thought"
TG_ACTION = TG + "action"
TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation"
TG_ANSWER = TG + "answer"
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
# Graphs
RETRIEVAL_GRAPH = "urn:graph:retrieval"
SOURCE_GRAPH = "urn:graph:source"
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
"""Query triples using the socket API."""
request = {
"user": user,
"collection": collection,
"limit": limit,
"streaming": False,
}
if s is not None:
request["s"] = {"t": "i", "i": s}
if p is not None:
request["p"] = {"t": "i", "i": p}
if o is not None:
if isinstance(o, str):
if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"):
request["o"] = {"t": "i", "i": o}
else:
request["o"] = {"t": "l", "v": o}
elif isinstance(o, dict):
request["o"] = o
if g is not None:
request["g"] = g
triples = []
try:
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
if isinstance(response, dict):
triple_list = response.get("response", response.get("triples", []))
else:
triple_list = response
if not isinstance(triple_list, list):
triple_list = [triple_list] if triple_list else []
for t in triple_list:
s_val = extract_value(t.get("s", {}))
p_val = extract_value(t.get("p", {}))
o_val = extract_value(t.get("o", {}))
triples.append((s_val, p_val, o_val))
except Exception as e:
print(f"Error querying triples: {e}", file=sys.stderr)
return triples
# Provenance predicates for edge tracing
TG = "https://trustgraph.ai/ns/"
TG_REIFIES = TG + "reifies"
PROV = "http://www.w3.org/ns/prov#"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
def extract_value(term):
"""Extract value from a term dict."""
if not term:
return ""
def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_client):
"""
Trace an edge back to its source document via reification.
t = term.get("t") or term.get("type")
Args:
flow: SocketFlowInstance
user: User identifier
collection: Collection identifier
edge: Dict with s, p, o keys
label_cache: Dict for caching labels
explain_client: ExplainabilityClient for label resolution
if t == "i":
return term.get("i") or term.get("iri", "")
elif t == "l":
return term.get("v") or term.get("value", "")
elif t == "t":
# Quoted triple
tr = term.get("tr") or term.get("triple", {})
return {
"s": extract_value(tr.get("s", {})),
"p": extract_value(tr.get("p", {})),
"o": extract_value(tr.get("o", {})),
}
Returns:
List of provenance chains, each chain is list of {uri, label}
"""
edge_s = edge.get("s", "")
edge_p = edge.get("p", "")
edge_o = edge.get("o", "")
# Fallback for raw values
if "i" in term:
return term["i"]
if "v" in term:
return term["v"]
# Build quoted triple for lookup
def build_term(val):
if isinstance(val, str) and (val.startswith("http") or val.startswith("urn:")):
return {"t": "i", "i": val}
return {"t": "l", "v": str(val)}
return str(term)
def get_node_properties(socket, flow_id, user, collection, node_uri, graph=RETRIEVAL_GRAPH):
"""Get all properties of a node as a dict."""
triples = query_triples(socket, flow_id, user, collection, s=node_uri, g=graph)
props = {}
for s, p, o in triples:
if p not in props:
props[p] = []
props[p].append(o)
return props
def find_by_predicate_object(socket, flow_id, user, collection, predicate, obj, graph=RETRIEVAL_GRAPH):
"""Find subjects where predicate = obj."""
triples = query_triples(socket, flow_id, user, collection, p=predicate, o=obj, g=graph)
return [s for s, p, o in triples]
def get_label(socket, flow_id, user, collection, uri, label_cache):
"""Get label for a URI, with caching."""
if not isinstance(uri, str) or not (uri.startswith("http://") or uri.startswith("https://") or uri.startswith("urn:")):
return uri
if uri in label_cache:
return label_cache[uri]
triples = query_triples(socket, flow_id, user, collection, s=uri, p=RDFS_LABEL)
for s, p, o in triples:
label_cache[uri] = o
return o
label_cache[uri] = uri
return uri
def get_document_content(api, user, doc_id, max_content):
"""Fetch document content from librarian API."""
try:
library = api.library()
content = library.get_document_content(user=user, id=doc_id)
# Try to decode as text
try:
text = content.decode('utf-8')
if len(text) > max_content:
return text[:max_content] + "... [truncated]"
return text
except UnicodeDecodeError:
return f"[Binary: {len(content)} bytes]"
except Exception as e:
return f"[Error fetching content: {e}]"
def trace_edge_provenance(socket, flow_id, user, collection, edge_s, edge_p, edge_o, label_cache):
"""Trace an edge back to its source document via reification."""
# Build the quoted triple for lookup
quoted_triple = {
"t": "t",
"tr": {
"s": {"t": "i", "i": edge_s} if isinstance(edge_s, str) and (edge_s.startswith("http") or edge_s.startswith("urn:")) else {"t": "l", "v": edge_s},
"p": {"t": "i", "i": edge_p},
"o": {"t": "i", "i": edge_o} if isinstance(edge_o, str) and (edge_o.startswith("http") or edge_o.startswith("urn:")) else {"t": "l", "v": edge_o},
"s": build_term(edge_s),
"p": build_term(edge_p),
"o": build_term(edge_o),
}
}
# Query: ?stmt tg:reifies <<edge>>
request = {
"user": user,
"collection": collection,
"limit": 10,
"streaming": False,
"p": {"t": "i", "i": TG_REIFIES},
"o": quoted_triple,
"g": SOURCE_GRAPH,
}
stmt_uris = []
try:
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
if isinstance(response, dict):
triple_list = response.get("response", response.get("triples", []))
else:
triple_list = response
if not isinstance(triple_list, list):
triple_list = [triple_list] if triple_list else []
for t in triple_list:
s_val = extract_value(t.get("s", {}))
if s_val:
stmt_uris.append(s_val)
results = flow.triples_query(
p=TG_REIFIES,
o=quoted_triple,
g=SOURCE_GRAPH,
user=user,
collection=collection,
limit=10
)
except Exception:
pass
return []
# For each statement, find wasDerivedFrom chain
# Extract statement URIs
stmt_uris = []
for t in results:
s_term = t.get("s", {})
s_val = s_term.get("i") or s_term.get("v", "")
if s_val:
stmt_uris.append(s_val)
# For each statement, trace wasDerivedFrom chain
provenance_chains = []
for stmt_uri in stmt_uris:
chain = trace_provenance_chain(socket, flow_id, user, collection, stmt_uri, label_cache)
chain = trace_provenance_chain(flow, user, collection, stmt_uri, label_cache, explain_client)
if chain:
provenance_chains.append(chain)
return provenance_chains
def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_cache, max_depth=10):
def trace_provenance_chain(flow, user, collection, start_uri, label_cache, explain_client, max_depth=10):
"""Trace prov:wasDerivedFrom chain from start_uri to root."""
chain = []
current = start_uri
@ -248,17 +119,32 @@ def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_c
if not current:
break
label = get_label(socket, flow_id, user, collection, current, label_cache)
# Get label
if current in label_cache:
label = label_cache[current]
else:
label = explain_client.resolve_label(current, user, collection)
label_cache[current] = label
chain.append({"uri": current, "label": label})
# Get parent
triples = query_triples(
socket, flow_id, user, collection,
s=current, p=PROV_WAS_DERIVED_FROM, g=SOURCE_GRAPH
)
# Get parent via wasDerivedFrom
try:
results = flow.triples_query(
s=current,
p=PROV_WAS_DERIVED_FROM,
g=SOURCE_GRAPH,
user=user,
collection=collection,
limit=1
)
except Exception:
break
parent = None
for s, p, o in triples:
parent = o
for t in results:
o_term = t.get("o", {})
parent = o_term.get("i") or o_term.get("v", "")
break
if not parent or parent == current:
@ -276,331 +162,24 @@ def format_provenance_chain(chain):
return " -> ".join(labels)
def format_edge(edge, label_cache=None, socket=None, flow_id=None, user=None, collection=None):
"""Format a quoted triple edge for display."""
if not isinstance(edge, dict):
return str(edge)
def print_graphrag_text(trace, explain_client, flow, user, collection, show_provenance=False):
"""Print GraphRAG trace in text format."""
question = trace.get("question")
s = edge.get("s", "?")
p = edge.get("p", "?")
o = edge.get("o", "?")
# Get labels if available
if label_cache and socket:
s_label = get_label(socket, flow_id, user, collection, s, label_cache)
p_label = get_label(socket, flow_id, user, collection, p, label_cache)
o_label = get_label(socket, flow_id, user, collection, o, label_cache)
else:
# Shorten URIs for display
s_label = s.split("/")[-1] if "/" in str(s) else s
p_label = p.split("/")[-1] if "/" in str(p) else p
o_label = o.split("/")[-1] if "/" in str(o) else o
return f"({s_label}, {p_label}, {o_label})"
def detect_trace_type(socket, flow_id, user, collection, entity_id):
"""
Detect whether an entity is an agent Question or GraphRAG Question.
Both have rdf:type = tg:Question, so we distinguish by checking
what's derived from it:
- Agent: has tg:Analysis or tg:Conclusion derived
- GraphRAG: has tg:Exploration derived
Also checks URI pattern as fallback:
- urn:trustgraph:agent: -> agent
- urn:trustgraph:question: -> graphrag
Returns:
"agent" or "graphrag"
"""
# Check URI pattern first (fast path)
if entity_id.startswith("urn:trustgraph:agent:"):
return "agent"
if entity_id.startswith("urn:trustgraph:question:"):
return "graphrag"
# Check what's derived from this entity
derived = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_DERIVED_FROM, entity_id
)
# Also check wasGeneratedBy (GraphRAG exploration uses this)
generated = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_GENERATED_BY, entity_id
)
all_children = derived + generated
for child_id in all_children:
child_types = query_triples(
socket, flow_id, user, collection,
s=child_id, p=RDF_TYPE, g=RETRIEVAL_GRAPH
)
for s, p, o in child_types:
if o == TG_ANALYSIS or o == TG_CONCLUSION:
return "agent"
if o == TG_EXPLORATION:
return "graphrag"
# Default to graphrag
return "graphrag"
def build_agent_trace(socket, flow_id, user, collection, session_id, api=None, max_answer=500):
"""Build the full explainability trace for an agent session."""
trace = {
"session_id": session_id,
"type": "agent",
"question": None,
"time": None,
"iterations": [],
"final_answer": None,
}
# Get session metadata
props = get_node_properties(socket, flow_id, user, collection, session_id)
trace["question"] = props.get(TG_QUERY, [None])[0]
trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0]
# Find all entities derived from this session (iterations and final)
# Start by looking for entities where prov:wasDerivedFrom = session_id
current_uri = session_id
iteration_num = 1
while True:
# Find entities derived from current
derived_ids = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_DERIVED_FROM, current_uri
)
if not derived_ids:
break
derived_id = derived_ids[0]
derived_props = get_node_properties(socket, flow_id, user, collection, derived_id)
# Check type
types = derived_props.get(RDF_TYPE, [])
if TG_ANALYSIS in types:
iteration = {
"id": derived_id,
"iteration_num": iteration_num,
"thought": derived_props.get(TG_THOUGHT, [None])[0],
"action": derived_props.get(TG_ACTION, [None])[0],
"arguments": derived_props.get(TG_ARGUMENTS, [None])[0],
"observation": derived_props.get(TG_OBSERVATION, [None])[0],
}
trace["iterations"].append(iteration)
current_uri = derived_id
iteration_num += 1
elif TG_CONCLUSION in types:
answer = derived_props.get(TG_ANSWER, [None])[0]
if answer and len(answer) > max_answer:
answer = answer[:max_answer] + "... [truncated]"
trace["final_answer"] = {
"id": derived_id,
"answer": answer,
}
break
else:
# Unknown type, stop traversal
break
return trace
def print_agent_text(trace):
"""Print agent trace in text format."""
print(f"=== Agent Session: {trace['session_id']} ===")
print(f"=== GraphRAG Session: {question.uri if question else 'Unknown'} ===")
print()
if trace["question"]:
print(f"Question: {trace['question']}")
if trace["time"]:
print(f"Time: {trace['time']}")
print()
# Analysis steps
print("--- Analysis ---")
iterations = trace.get("iterations", [])
if iterations:
for iteration in iterations:
print(f"Analysis {iteration['iteration_num']}:")
print(f" Thought: {iteration.get('thought', 'N/A')}")
print(f" Action: {iteration.get('action', 'N/A')}")
args = iteration.get('arguments')
if args:
# Try to pretty-print JSON arguments
try:
import json
args_obj = json.loads(args)
args_str = json.dumps(args_obj, indent=4)
# Indent each line
args_lines = args_str.split('\n')
print(f" Arguments:")
for line in args_lines:
print(f" {line}")
except:
print(f" Arguments: {args}")
else:
print(f" Arguments: N/A")
obs = iteration.get('observation', 'N/A')
if obs and len(obs) > 200:
obs = obs[:200] + "... [truncated]"
print(f" Observation: {obs}")
print()
else:
print("No analysis steps recorded")
print()
# Conclusion
print("--- Conclusion ---")
final = trace.get("final_answer")
if final and final.get("answer"):
print("Answer:")
for line in final["answer"].split("\n"):
print(f" {line}")
else:
print("No conclusion recorded")
def print_agent_json(trace):
"""Print agent trace as JSON."""
print(json.dumps(trace, indent=2))
def build_trace(socket, flow_id, user, collection, question_id, api=None, show_provenance=False, max_answer=500):
"""Build the full explainability trace for a question."""
label_cache = {}
trace = {
"question_id": question_id,
"question": None,
"time": None,
"exploration": None,
"focus": None,
"synthesis": None,
}
# Get question metadata
props = get_node_properties(socket, flow_id, user, collection, question_id)
trace["question"] = props.get(TG_QUERY, [None])[0]
trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0]
# Find exploration: ?exploration prov:wasGeneratedBy question_id
exploration_ids = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_GENERATED_BY, question_id
)
if exploration_ids:
exploration_id = exploration_ids[0]
exploration_props = get_node_properties(socket, flow_id, user, collection, exploration_id)
trace["exploration"] = {
"id": exploration_id,
"edge_count": exploration_props.get(TG_EDGE_COUNT, [None])[0],
}
# Find focus: ?focus prov:wasDerivedFrom exploration_id
focus_ids = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_DERIVED_FROM, exploration_id
)
if focus_ids:
focus_id = focus_ids[0]
focus_props = get_node_properties(socket, flow_id, user, collection, focus_id)
# Get selected edges
edge_selection_uris = focus_props.get(TG_SELECTED_EDGE, [])
selected_edges = []
for edge_sel_uri in edge_selection_uris:
edge_sel_props = get_node_properties(socket, flow_id, user, collection, edge_sel_uri)
edge = edge_sel_props.get(TG_EDGE, [None])[0]
reasoning = edge_sel_props.get(TG_REASONING, [None])[0]
edge_info = {
"edge": edge,
"reasoning": reasoning,
}
# Trace provenance if requested
if show_provenance and isinstance(edge, dict):
provenance = trace_edge_provenance(
socket, flow_id, user, collection,
edge.get("s", ""), edge.get("p", ""), edge.get("o", ""),
label_cache
)
edge_info["provenance"] = provenance
selected_edges.append(edge_info)
trace["focus"] = {
"id": focus_id,
"selected_edges": selected_edges,
}
# Find synthesis: ?synthesis prov:wasDerivedFrom focus_id
synthesis_ids = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_DERIVED_FROM, focus_id
)
if synthesis_ids:
synthesis_id = synthesis_ids[0]
synthesis_props = get_node_properties(socket, flow_id, user, collection, synthesis_id)
# Get content directly or via document reference
content = synthesis_props.get(TG_CONTENT, [None])[0]
doc_id = synthesis_props.get(TG_DOCUMENT, [None])[0]
if not content and doc_id and api:
content = get_document_content(api, user, doc_id, max_answer)
elif content and len(content) > max_answer:
content = content[:max_answer] + "... [truncated]"
trace["synthesis"] = {
"id": synthesis_id,
"document_id": doc_id,
"answer": content,
}
# Store label cache for formatting
trace["_label_cache"] = label_cache
return trace
def print_text(trace, show_provenance=False):
"""Print trace in text format."""
label_cache = trace.get("_label_cache", {})
print(f"=== GraphRAG Session: {trace['question_id']} ===")
print()
if trace["question"]:
print(f"Question: {trace['question']}")
if trace["time"]:
print(f"Time: {trace['time']}")
if question:
print(f"Question: {question.query}")
if question.timestamp:
print(f"Time: {question.timestamp}")
print()
# Exploration
print("--- Exploration ---")
exploration = trace.get("exploration")
if exploration:
edge_count = exploration.get("edge_count", "?")
print(f"Retrieved {edge_count} edges from knowledge graph")
print(f"Retrieved {exploration.edge_count} edges from knowledge graph")
else:
print("No exploration data found")
print()
@ -609,24 +188,28 @@ def print_text(trace, show_provenance=False):
print("--- Focus (Edge Selection) ---")
focus = trace.get("focus")
if focus:
edges = focus.get("selected_edges", [])
edges = focus.edge_selections
print(f"Selected {len(edges)} edges:")
print()
for i, edge_info in enumerate(edges, 1):
edge = edge_info.get("edge")
reasoning = edge_info.get("reasoning")
label_cache = {}
if edge:
edge_str = format_edge(edge)
print(f" {i}. {edge_str}")
for i, edge_sel in enumerate(edges, 1):
if edge_sel.edge:
s_label, p_label, o_label = explain_client.resolve_edge_labels(
edge_sel.edge, user, collection
)
print(f" {i}. ({s_label}, {p_label}, {o_label})")
if reasoning:
r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
if edge_sel.reasoning:
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reasoning: {r_short}")
if show_provenance:
provenance = edge_info.get("provenance", [])
if show_provenance and edge_sel.edge:
provenance = trace_edge_provenance(
flow, user, collection, edge_sel.edge,
label_cache, explain_client
)
for chain in provenance:
chain_str = format_provenance_chain(chain)
if chain_str:
@ -641,11 +224,9 @@ def print_text(trace, show_provenance=False):
print("--- Synthesis ---")
synthesis = trace.get("synthesis")
if synthesis:
answer = synthesis.get("answer")
if answer:
if synthesis.content:
print("Answer:")
# Indent the answer
for line in answer.split("\n"):
for line in synthesis.content.split("\n"):
print(f" {line}")
else:
print("No answer content found")
@ -653,11 +234,173 @@ def print_text(trace, show_provenance=False):
print("No synthesis data found")
def print_json(trace):
"""Print trace as JSON."""
# Remove internal cache before printing
output = {k: v for k, v in trace.items() if not k.startswith("_")}
print(json.dumps(output, indent=2))
def print_docrag_text(trace):
"""Print DocRAG trace in text format."""
question = trace.get("question")
print(f"=== DocRAG Session: {question.uri if question else 'Unknown'} ===")
print()
if question:
print(f"Question: {question.query}")
if question.timestamp:
print(f"Time: {question.timestamp}")
print()
# Exploration
print("--- Exploration ---")
exploration = trace.get("exploration")
if exploration:
print(f"Retrieved {exploration.chunk_count} chunks from document store")
else:
print("No exploration data found")
print()
# Synthesis (no Focus step for DocRAG)
print("--- Synthesis ---")
synthesis = trace.get("synthesis")
if synthesis:
if synthesis.content:
print("Answer:")
for line in synthesis.content.split("\n"):
print(f" {line}")
else:
print("No answer content found")
else:
print("No synthesis data found")
def print_agent_text(trace):
"""Print Agent trace in text format."""
question = trace.get("question")
print(f"=== Agent Session: {question.uri if question else 'Unknown'} ===")
print()
if question:
print(f"Question: {question.query}")
if question.timestamp:
print(f"Time: {question.timestamp}")
print()
# Analysis steps
print("--- Analysis ---")
iterations = trace.get("iterations", [])
if iterations:
for i, analysis in enumerate(iterations, 1):
print(f"Analysis {i}:")
print(f" Thought: {analysis.thought or 'N/A'}")
print(f" Action: {analysis.action or 'N/A'}")
if analysis.arguments:
# Try to pretty-print JSON arguments
try:
args_obj = json.loads(analysis.arguments)
args_str = json.dumps(args_obj, indent=4)
print(f" Arguments:")
for line in args_str.split('\n'):
print(f" {line}")
except Exception:
print(f" Arguments: {analysis.arguments}")
else:
print(f" Arguments: N/A")
obs = analysis.observation or 'N/A'
if obs and len(obs) > 200:
obs = obs[:200] + "... [truncated]"
print(f" Observation: {obs}")
print()
else:
print("No analysis steps recorded")
print()
# Conclusion
print("--- Conclusion ---")
conclusion = trace.get("conclusion")
if conclusion and conclusion.answer:
print("Answer:")
for line in conclusion.answer.split("\n"):
print(f" {line}")
else:
print("No conclusion recorded")
def trace_to_dict(trace, trace_type):
"""Convert trace entities to JSON-serializable dict."""
if trace_type == "agent":
question = trace.get("question")
return {
"type": "agent",
"session_id": question.uri if question else None,
"question": question.query if question else None,
"time": question.timestamp if question else None,
"iterations": [
{
"id": a.uri,
"thought": a.thought,
"action": a.action,
"arguments": a.arguments,
"observation": a.observation,
}
for a in trace.get("iterations", [])
],
"conclusion": {
"id": trace["conclusion"].uri,
"answer": trace["conclusion"].answer,
} if trace.get("conclusion") else None,
}
elif trace_type == "docrag":
question = trace.get("question")
exploration = trace.get("exploration")
synthesis = trace.get("synthesis")
return {
"type": "docrag",
"question_id": question.uri if question else None,
"question": question.query if question else None,
"time": question.timestamp if question else None,
"exploration": {
"id": exploration.uri,
"chunk_count": exploration.chunk_count,
} if exploration else None,
"synthesis": {
"id": synthesis.uri,
"document_uri": synthesis.document_uri,
"answer": synthesis.content,
} if synthesis else None,
}
else:
# graphrag
question = trace.get("question")
exploration = trace.get("exploration")
focus = trace.get("focus")
synthesis = trace.get("synthesis")
return {
"type": "graphrag",
"question_id": question.uri if question else None,
"question": question.query if question else None,
"time": question.timestamp if question else None,
"exploration": {
"id": exploration.uri,
"edge_count": exploration.edge_count,
} if exploration else None,
"focus": {
"id": focus.uri,
"selected_edges": [
{
"edge": edge_sel.edge,
"reasoning": edge_sel.reasoning,
}
for edge_sel in focus.edge_selections
],
} if focus else None,
"synthesis": {
"id": synthesis.uri,
"document_uri": synthesis.document_uri,
"answer": synthesis.content,
} if synthesis else None,
}
def main():
@ -727,50 +470,69 @@ def main():
try:
api = Api(args.api_url, token=args.token)
socket = api.socket()
flow = socket.flow(args.flow_id)
explain_client = ExplainabilityClient(flow)
try:
# Detect trace type (agent vs graphrag)
trace_type = detect_trace_type(
socket=socket,
flow_id=args.flow_id,
# Detect trace type
trace_type = explain_client.detect_session_type(
args.question_id,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
entity_id=args.question_id,
)
if trace_type == "agent":
# Build and print agent trace
trace = build_agent_trace(
socket=socket,
flow_id=args.flow_id,
# Fetch and display agent trace
trace = explain_client.fetch_agent_trace(
args.question_id,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
session_id=args.question_id,
api=api,
max_answer=args.max_answer,
max_content=args.max_answer,
)
if args.format == 'json':
print_agent_json(trace)
print(json.dumps(trace_to_dict(trace, "agent"), indent=2))
else:
print_agent_text(trace)
else:
# Build and print GraphRAG trace (existing behavior)
trace = build_trace(
socket=socket,
flow_id=args.flow_id,
elif trace_type == "docrag":
# Fetch and display DocRAG trace
trace = explain_client.fetch_docrag_trace(
args.question_id,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
question_id=args.question_id,
api=api,
show_provenance=args.show_provenance,
max_answer=args.max_answer,
max_content=args.max_answer,
)
if args.format == 'json':
print_json(trace)
print(json.dumps(trace_to_dict(trace, "docrag"), indent=2))
else:
print_text(trace, show_provenance=args.show_provenance)
print_docrag_text(trace)
else:
# Fetch and display GraphRAG trace
trace = explain_client.fetch_graphrag_trace(
args.question_id,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
api=api,
max_content=args.max_answer,
)
if args.format == 'json':
print(json.dumps(trace_to_dict(trace, "graphrag"), indent=2))
else:
print_graphrag_text(
trace, explain_client, flow,
args.user, args.collection,
show_provenance=args.show_provenance
)
finally:
socket.close()

View file

@ -2,6 +2,8 @@
Simple agent infrastructure broadly implements the ReAct flow.
"""
import asyncio
import base64
import json
import re
import sys
@ -17,9 +19,13 @@ from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
from ... base import ProducerSpec
from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
# Provenance imports for agent explainability
from trustgraph.provenance import (
@ -41,6 +47,8 @@ from . types import Final, Action, Tool, Argument
default_ident = "agent-manager"
default_max_iterations = 10
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(AgentService):
@ -129,6 +137,115 @@ class Processor(AgentService):
)
)
# Librarian client for storing answer content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_librarian_requests = {}
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
else:
logger.warning(f"Received unexpected librarian response: {request_id}")
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
user: User ID
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or "Agent Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_librarian_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_tools_config(self, config, version):
logger.info(f"Loading configuration version {version}")
@ -347,6 +464,15 @@ class Processor(AgentService):
))
logger.debug(f"Emitted session triples for {session_uri}")
# Send explain event for session
if streaming:
await respond(AgentResponse(
chunk_type="explain",
content="",
explain_id=session_uri,
explain_graph=GRAPH_RETRIEVAL,
))
logger.info(f"Question: {request.question}")
if len(history) >= self.max_iterations:
@ -504,8 +630,28 @@ class Processor(AgentService):
else:
parent_uri = session_uri
# Save answer to librarian
answer_doc_id = None
if f:
answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
try:
await self.save_answer_content(
doc_id=answer_doc_id,
user=request.user,
content=f,
title=f"Agent Answer: {request.question[:50]}...",
)
logger.debug(f"Saved answer to librarian: {answer_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
answer_doc_id = None # Fall back to inline content
final_triples = set_graph(
agent_final_triples(final_uri, parent_uri, f),
agent_final_triples(
final_uri, parent_uri,
answer="" if answer_doc_id else f,
document_id=answer_doc_id,
),
GRAPH_RETRIEVAL
)
await flow("explainability").send(Triples(
@ -518,6 +664,15 @@ class Processor(AgentService):
))
logger.debug(f"Emitted final triples for {final_uri}")
# Send explain event for conclusion
if streaming:
await respond(AgentResponse(
chunk_type="explain",
content="",
explain_id=final_uri,
explain_graph=GRAPH_RETRIEVAL,
))
if streaming:
# Streaming format - send end-of-dialog marker
# Answer chunks were already sent via answer() callback during parsing
@ -558,14 +713,48 @@ class Processor(AgentService):
else:
parent_uri = session_uri
# Save thought to librarian
thought_doc_id = None
if act.thought:
thought_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
try:
await self.save_answer_content(
doc_id=thought_doc_id,
user=request.user,
content=act.thought,
title=f"Agent Thought: {act.name}",
)
logger.debug(f"Saved thought to librarian: {thought_doc_id}")
except Exception as e:
logger.warning(f"Failed to save thought to librarian: {e}")
thought_doc_id = None
# Save observation to librarian
observation_doc_id = None
if act.observation:
observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
try:
await self.save_answer_content(
doc_id=observation_doc_id,
user=request.user,
content=act.observation,
title=f"Agent Observation: {act.name}",
)
logger.debug(f"Saved observation to librarian: {observation_doc_id}")
except Exception as e:
logger.warning(f"Failed to save observation to librarian: {e}")
observation_doc_id = None
iter_triples = set_graph(
agent_iteration_triples(
iteration_uri,
parent_uri,
act.thought,
act.name,
act.arguments,
act.observation,
thought="" if thought_doc_id else act.thought,
action=act.name,
arguments=act.arguments,
observation="" if observation_doc_id else act.observation,
thought_document_id=thought_doc_id,
observation_document_id=observation_doc_id,
),
GRAPH_RETRIEVAL
)
@ -579,6 +768,15 @@ class Processor(AgentService):
))
logger.debug(f"Emitted iteration triples for {iteration_uri}")
# Send explain event for iteration
if streaming:
await respond(AgentResponse(
chunk_type="explain",
content="",
explain_id=iteration_uri,
explain_graph=GRAPH_RETRIEVAL,
))
history.append(act)
# Handle state transitions if tool execution was successful

View file

@ -109,7 +109,7 @@ class DocumentRag:
async def query(
self, query, user="trustgraph", collection="default",
doc_limit=20, streaming=False, chunk_callback=None,
explain_callback=None,
explain_callback=None, save_answer_callback=None,
):
"""
Execute a Document RAG query with optional explainability tracking.
@ -122,6 +122,7 @@ class DocumentRag:
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for explainability
save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
Returns:
str: The synthesized answer text
@ -192,9 +193,28 @@ class DocumentRag:
# Emit synthesis explainability after answer generated
if explain_callback:
synthesis_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian if callback provided
if save_answer_callback and answer_text:
# Generate document ID as URN matching query-time provenance format
synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
try:
await save_answer_callback(synthesis_doc_id, answer_text)
if self.verbose:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None # Fall back to inline content
# Generate triples with document reference or inline content
syn_triples = set_graph(
docrag_synthesis_triples(syn_uri, exp_uri, answer_text),
docrag_synthesis_triples(
syn_uri, exp_uri,
answer_text="" if synthesis_doc_id else answer_text,
document_id=synthesis_doc_id,
),
GRAPH_RETRIEVAL
)
await explain_callback(syn_triples, syn_uri)

View file

@ -8,8 +8,10 @@ import asyncio
import base64
import logging
import uuid
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
from ... schema import LibrarianRequest, LibrarianResponse
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples, Metadata
from ... provenance import GRAPH_RETRIEVAL
@ -179,6 +181,62 @@ class Processor(FlowProcessor):
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching chunk {chunk_id}")
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
user: User ID
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or "DocumentRAG Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_request(self, msg, consumer, flow):
try:
@ -222,10 +280,20 @@ class Processor(FlowProcessor):
response=None,
explain_id=explain_id,
explain_graph=GRAPH_RETRIEVAL,
message_type="explain",
),
properties={"id": id}
)
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
doc_id=doc_id,
user=v.user,
content=answer_text,
title=f"DocumentRAG Answer: {v.query[:50]}...",
)
# Check if streaming is requested
if v.streaming:
# Define async callback for streaming chunks
@ -235,6 +303,7 @@ class Processor(FlowProcessor):
DocumentRagResponse(
response=chunk,
end_of_stream=end_of_stream,
message_type="chunk",
error=None
),
properties={"id": id}
@ -250,6 +319,17 @@ class Processor(FlowProcessor):
streaming=True,
chunk_callback=send_chunk,
explain_callback=send_explainability,
save_answer_callback=save_answer,
)
# Send end_of_session to signal entire session is complete
await flow("response").send(
DocumentRagResponse(
response=None,
end_of_session=True,
message_type="end",
),
properties={"id": id}
)
else:
# Non-streaming path (existing behavior)
@ -259,6 +339,7 @@ class Processor(FlowProcessor):
collection=v.collection,
doc_limit=doc_limit,
explain_callback=send_explainability,
save_answer_callback=save_answer,
)
await flow("response").send(