Add unified explainability support and librarian storage for (#693)

Add unified explainability support and librarian storage for all retrieval engines

Implements consistent explainability/provenance tracking
across GraphRAG, DocumentRAG, and Agent retrieval
engines. All large content (answers, thoughts, observations)
is now stored in librarian rather than as inline literals in
the knowledge graph.

Explainability API:
- New explainability.py module with entity classes (Question,
  Exploration, Focus, Synthesis, Analysis, Conclusion) and
  ExplainabilityClient
- Quiescence-based eventual consistency handling for trace
  fetching
- Content fetching from librarian with retry logic

CLI updates:
- tg-invoke-graph-rag -x/--explainable flag returns
  explain_id
- tg-invoke-document-rag -x/--explainable flag returns
  explain_id
- tg-invoke-agent -x/--explainable flag returns explain_id
- tg-list-explain-traces uses new explainability API
- tg-show-explain-trace handles all three trace types

Agent provenance:
- Records session, iterations (think/act/observe), and conclusion
- Stores thoughts and observations in librarian with document
  references
- New predicates: tg:thoughtDocument, tg:observationDocument

DocumentRAG provenance:
- Records question, exploration (chunk retrieval), and synthesis
- Stores answers in librarian with document references

Schema changes:
- AgentResponse: added explain_id, explain_graph fields
- RetrievalResponse: added explain_id, explain_graph fields
- agent_iteration_triples: supports thought_document_id,
  observation_document_id

Update tests.
This commit is contained in:
cybermaggedon 2026-03-12 21:40:09 +00:00 committed by GitHub
parent aecf00f040
commit 35128ff019
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 2736 additions and 846 deletions

View file

@ -110,16 +110,17 @@ class TestRAGTranslatorCompletionFlags:
assert response_dict["end_of_stream"] is True assert response_dict["end_of_stream"] is True
assert response_dict["end_of_session"] is False assert response_dict["end_of_session"] is False
def test_document_rag_translator_is_final_with_end_of_stream_true(self): def test_document_rag_translator_is_final_with_end_of_session_true(self):
""" """
Test that DocumentRagResponseTranslator returns is_final=True Test that DocumentRagResponseTranslator returns is_final=True
when end_of_stream=True. when end_of_session=True.
""" """
# Arrange # Arrange
translator = TranslatorRegistry.get_response_translator("document-rag") translator = TranslatorRegistry.get_response_translator("document-rag")
response = DocumentRagResponse( response = DocumentRagResponse(
response="A document about cats.", response="A document about cats.",
end_of_stream=True, end_of_stream=True,
end_of_session=True,
error=None error=None
) )
@ -127,9 +128,31 @@ class TestRAGTranslatorCompletionFlags:
response_dict, is_final = translator.from_response_with_completion(response) response_dict, is_final = translator.from_response_with_completion(response)
# Assert # Assert
assert is_final is True, "is_final must be True when end_of_stream=True" assert is_final is True, "is_final must be True when end_of_session=True"
assert response_dict["response"] == "A document about cats." assert response_dict["response"] == "A document about cats."
assert response_dict["end_of_session"] is True
def test_document_rag_translator_end_of_stream_not_final(self):
"""
Test that end_of_stream=True alone does NOT make is_final=True.
The session continues with provenance messages after LLM stream completes.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("document-rag")
response = DocumentRagResponse(
response="Final chunk",
end_of_stream=True,
end_of_session=False, # Session continues with provenance
error=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
assert response_dict["end_of_stream"] is True assert response_dict["end_of_stream"] is True
assert response_dict["end_of_session"] is False
def test_document_rag_translator_is_final_with_end_of_stream_false(self): def test_document_rag_translator_is_final_with_end_of_stream_false(self):
""" """

View file

@ -30,10 +30,13 @@ class TestAgentStructuredQueryIntegration:
pulsar_client=AsyncMock(), pulsar_client=AsyncMock(),
max_iterations=3 max_iterations=3
) )
# Mock the client method for structured query # Mock the client method for structured query
proc.client = MagicMock() proc.client = MagicMock()
# Mock librarian to avoid hanging on save operations
proc.save_answer_content = AsyncMock(return_value=None)
return proc return proc
@pytest.fixture @pytest.fixture

View file

@ -28,6 +28,9 @@ class TestAgentServiceNonStreaming:
max_iterations=10 max_iterations=10
) )
# Mock librarian to avoid hanging on save operations
processor.save_answer_content = AsyncMock(return_value=None)
# Track all responses sent # Track all responses sent
sent_responses = [] sent_responses = []
@ -106,6 +109,9 @@ class TestAgentServiceNonStreaming:
max_iterations=10 max_iterations=10
) )
# Mock librarian to avoid hanging on save operations
processor.save_answer_content = AsyncMock(return_value=None)
# Track all responses sent # Track all responses sent
sent_responses = [] sent_responses = []
@ -173,6 +179,9 @@ class TestAgentServiceNonStreaming:
max_iterations=10 max_iterations=10
) )
# Mock librarian to avoid hanging on save operations
processor.save_answer_content = AsyncMock(return_value=None)
# Track all responses sent # Track all responses sent
sent_responses = [] sent_responses = []

View file

@ -68,6 +68,7 @@ class TestDocumentRagService:
collection="test_coll_1", # Must be from message, not hardcoded default collection="test_coll_1", # Must be from message, not hardcoded default
doc_limit=5, doc_limit=5,
explain_callback=ANY, # Explainability callback is always passed explain_callback=ANY, # Explainability callback is always passed
save_answer_callback=ANY, # Librarian save callback is always passed
) )
# Verify response was sent # Verify response was sent

View file

@ -59,7 +59,7 @@ from .flow import Flow, FlowInstance
from .async_flow import AsyncFlow, AsyncFlowInstance from .async_flow import AsyncFlow, AsyncFlowInstance
# WebSocket clients # WebSocket clients
from .socket_client import SocketClient, SocketFlowInstance from .socket_client import SocketClient, SocketFlowInstance, build_term
from .async_socket_client import AsyncSocketClient, AsyncSocketFlowInstance from .async_socket_client import AsyncSocketClient, AsyncSocketFlowInstance
# Bulk operation clients # Bulk operation clients
@ -70,6 +70,21 @@ from .async_bulk_client import AsyncBulkClient
from .metrics import Metrics from .metrics import Metrics
from .async_metrics import AsyncMetrics from .async_metrics import AsyncMetrics
# Explainability
from .explainability import (
ExplainabilityClient,
ExplainEntity,
Question,
Exploration,
Focus,
Synthesis,
Analysis,
Conclusion,
EdgeSelection,
wire_triples_to_tuples,
extract_term_value,
)
# Types # Types
from .types import ( from .types import (
Triple, Triple,
@ -85,6 +100,7 @@ from .types import (
AgentObservation, AgentObservation,
AgentAnswer, AgentAnswer,
RAGChunk, RAGChunk,
ProvenanceEvent,
) )
# Exceptions # Exceptions
@ -124,6 +140,7 @@ __all__ = [
"SocketFlowInstance", "SocketFlowInstance",
"AsyncSocketClient", "AsyncSocketClient",
"AsyncSocketFlowInstance", "AsyncSocketFlowInstance",
"build_term",
# Bulk operation clients # Bulk operation clients
"BulkClient", "BulkClient",
@ -133,6 +150,19 @@ __all__ = [
"Metrics", "Metrics",
"AsyncMetrics", "AsyncMetrics",
# Explainability
"ExplainabilityClient",
"ExplainEntity",
"Question",
"Exploration",
"Focus",
"Synthesis",
"Analysis",
"Conclusion",
"EdgeSelection",
"wire_triples_to_tuples",
"extract_term_value",
# Types # Types
"Triple", "Triple",
"Uri", "Uri",
@ -147,6 +177,7 @@ __all__ = [
"AgentObservation", "AgentObservation",
"AgentAnswer", "AgentAnswer",
"RAGChunk", "RAGChunk",
"ProvenanceEvent",
# Exceptions # Exceptions
"ProtocolException", "ProtocolException",

File diff suppressed because it is too large Load diff

View file

@ -15,6 +15,63 @@ from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, Strea
from . exceptions import ProtocolException, raise_from_error_dict from . exceptions import ProtocolException, raise_from_error_dict
def build_term(value: Any, term_type: Optional[str] = None,
datatype: Optional[str] = None, language: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""
Build wire-format Term dict from a value.
Auto-detection rules (when term_type is None):
- Already a dict with 't' key -> return as-is (already a Term)
- Starts with http://, https://, urn: -> IRI
- Wrapped in <> (e.g., <http://...>) -> IRI (angle brackets stripped)
- Anything else -> literal
Args:
value: The term value (string, dict, or None)
term_type: One of 'iri', 'literal', or None for auto-detect
datatype: Datatype for literal objects (e.g., xsd:integer)
language: Language tag for literal objects (e.g., en)
Returns:
dict: Wire-format Term dict, or None if value is None
"""
if value is None:
return None
# If already a Term dict, return as-is
if isinstance(value, dict) and "t" in value:
return value
# Convert to string for processing
value = str(value)
# Auto-detect type if not specified
if term_type is None:
if value.startswith("<") and value.endswith(">") and not value.startswith("<<"):
# Angle-bracket wrapped IRI: <http://...>
value = value[1:-1] # Strip < and >
term_type = "iri"
elif value.startswith(("http://", "https://", "urn:")):
term_type = "iri"
else:
term_type = "literal"
if term_type == "iri":
# Strip angle brackets if present
if value.startswith("<") and value.endswith(">"):
value = value[1:-1]
return {"t": "i", "i": value}
elif term_type == "literal":
result = {"t": "l", "v": value}
if datatype:
result["dt"] = datatype
if language:
result["ln"] = language
return result
else:
raise ValueError(f"Unknown term type: {term_type}")
class SocketClient: class SocketClient:
""" """
Synchronous WebSocket client for streaming operations. Synchronous WebSocket client for streaming operations.
@ -92,7 +149,8 @@ class SocketClient:
flow: Optional[str], flow: Optional[str],
request: Dict[str, Any], request: Dict[str, Any],
streaming: bool = False, streaming: bool = False,
streaming_raw: bool = False streaming_raw: bool = False,
include_provenance: bool = False
) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]: ) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]:
"""Synchronous wrapper around async WebSocket communication. """Synchronous wrapper around async WebSocket communication.
@ -119,7 +177,7 @@ class SocketClient:
return self._streaming_generator_raw(service, flow, request, loop) return self._streaming_generator_raw(service, flow, request, loop)
elif streaming: elif streaming:
# Parsed streaming for agent/RAG chunk types # Parsed streaming for agent/RAG chunk types
return self._streaming_generator(service, flow, request, loop) return self._streaming_generator(service, flow, request, loop, include_provenance)
else: else:
# Non-streaming single response # Non-streaming single response
return loop.run_until_complete(self._send_request_async(service, flow, request)) return loop.run_until_complete(self._send_request_async(service, flow, request))
@ -129,10 +187,11 @@ class SocketClient:
service: str, service: str,
flow: Optional[str], flow: Optional[str],
request: Dict[str, Any], request: Dict[str, Any],
loop: asyncio.AbstractEventLoop loop: asyncio.AbstractEventLoop,
include_provenance: bool = False
) -> Iterator[StreamingChunk]: ) -> Iterator[StreamingChunk]:
"""Generator that yields streaming chunks (for agent/RAG responses)""" """Generator that yields streaming chunks (for agent/RAG responses)"""
async_gen = self._send_request_async_streaming(service, flow, request) async_gen = self._send_request_async_streaming(service, flow, request, include_provenance)
try: try:
while True: while True:
@ -265,7 +324,8 @@ class SocketClient:
self, self,
service: str, service: str,
flow: Optional[str], flow: Optional[str],
request: Dict[str, Any] request: Dict[str, Any],
include_provenance: bool = False
) -> Iterator[StreamingChunk]: ) -> Iterator[StreamingChunk]:
"""Async implementation of WebSocket request (streaming)""" """Async implementation of WebSocket request (streaming)"""
# Generate unique request ID # Generate unique request ID
@ -309,8 +369,8 @@ class SocketClient:
raise_from_error_dict(resp["error"]) raise_from_error_dict(resp["error"])
# Parse different chunk types # Parse different chunk types
chunk = self._parse_chunk(resp) chunk = self._parse_chunk(resp, include_provenance=include_provenance)
if chunk is not None: # Skip provenance messages in streaming if chunk is not None: # Skip provenance messages unless include_provenance
yield chunk yield chunk
# Check if this is the final message # Check if this is the final message
@ -325,14 +385,26 @@ class SocketClient:
chunk_type = resp.get("chunk_type") chunk_type = resp.get("chunk_type")
message_type = resp.get("message_type") message_type = resp.get("message_type")
# Handle new GraphRAG message format with message_type # Handle GraphRAG/DocRAG message format with message_type
if message_type == "provenance": if message_type == "explain":
if include_provenance: if include_provenance:
# Return provenance event for explainability # Return provenance event for explainability
return ProvenanceEvent(provenance_id=resp.get("provenance_id", "")) return ProvenanceEvent(
explain_id=resp.get("explain_id", ""),
explain_graph=resp.get("explain_graph", "")
)
# Provenance messages are not yielded to user - they're metadata # Provenance messages are not yielded to user - they're metadata
return None return None
# Handle Agent message format with chunk_type="explain"
if chunk_type == "explain":
if include_provenance:
return ProvenanceEvent(
explain_id=resp.get("explain_id", ""),
explain_graph=resp.get("explain_graph", "")
)
return None
if chunk_type == "thought": if chunk_type == "thought":
return AgentThought( return AgentThought(
content=resp.get("content", ""), content=resp.get("content", ""),
@ -477,6 +549,95 @@ class SocketFlowInstance:
# regardless of streaming flag, so always use the streaming code path # regardless of streaming flag, so always use the streaming code path
return self.client._send_request_sync("agent", self.flow_id, request, streaming=True) return self.client._send_request_sync("agent", self.flow_id, request, streaming=True)
def agent_explain(
self,
question: str,
user: str,
collection: str,
state: Optional[Dict[str, Any]] = None,
group: Optional[str] = None,
history: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any
) -> Iterator[Union[StreamingChunk, ProvenanceEvent]]:
"""
Execute an agent operation with explainability support.
Streams both content chunks (AgentThought, AgentObservation, AgentAnswer)
and provenance events (ProvenanceEvent). Provenance events contain URIs
that can be fetched using ExplainabilityClient to get detailed information
about the agent's reasoning process.
Agent trace consists of:
- Session: The initial question and session metadata
- Iterations: Each thought/action/observation cycle
- Conclusion: The final answer
Args:
question: User question or instruction
user: User identifier
collection: Collection identifier for provenance storage
state: Optional state dictionary for stateful conversations
group: Optional group identifier for multi-user contexts
history: Optional conversation history as list of message dicts
**kwargs: Additional parameters passed to the agent service
Yields:
Union[StreamingChunk, ProvenanceEvent]: Agent chunks and provenance events
Example:
```python
from trustgraph.api import Api, ExplainabilityClient, ProvenanceEvent
from trustgraph.api import AgentThought, AgentObservation, AgentAnswer
socket = api.socket()
flow = socket.flow("default")
explain_client = ExplainabilityClient(flow)
provenance_ids = []
for item in flow.agent_explain(
question="What is the capital of France?",
user="trustgraph",
collection="default"
):
if isinstance(item, AgentThought):
print(f"[Thought] {item.content}")
elif isinstance(item, AgentObservation):
print(f"[Observation] {item.content}")
elif isinstance(item, AgentAnswer):
print(f"[Answer] {item.content}")
elif isinstance(item, ProvenanceEvent):
provenance_ids.append(item.explain_id)
# Fetch session trace after completion
if provenance_ids:
trace = explain_client.fetch_agent_trace(
provenance_ids[0], # Session URI is first
graph="urn:graph:retrieval",
user="trustgraph",
collection="default"
)
```
"""
request = {
"question": question,
"user": user,
"collection": collection,
"streaming": True # Always streaming for explain
}
if state is not None:
request["state"] = state
if group is not None:
request["group"] = group
if history is not None:
request["history"] = history
request.update(kwargs)
# Use streaming with provenance enabled
return self.client._send_request_sync(
"agent", self.flow_id, request,
streaming=True, include_provenance=True
)
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]: def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
""" """
Execute text completion with optional streaming. Execute text completion with optional streaming.
@ -596,6 +757,86 @@ class SocketFlowInstance:
else: else:
return result.get("response", "") return result.get("response", "")
def graph_rag_explain(
self,
query: str,
user: str,
collection: str,
max_subgraph_size: int = 1000,
max_subgraph_count: int = 5,
max_entity_distance: int = 3,
**kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""
Execute graph-based RAG query with explainability support.
Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent).
Provenance events contain URIs that can be fetched using ExplainabilityClient
to get detailed information about how the response was generated.
Args:
query: Natural language query
user: User/keyspace identifier
collection: Collection identifier
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
max_subgraph_count: Maximum number of subgraphs (default: 5)
max_entity_distance: Maximum traversal depth (default: 3)
**kwargs: Additional parameters passed to the service
Yields:
Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events
Example:
```python
from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent
socket = api.socket()
flow = socket.flow("default")
explain_client = ExplainabilityClient(flow)
provenance_ids = []
response_text = ""
for item in flow.graph_rag_explain(
query="Tell me about Marie Curie",
user="trustgraph",
collection="scientists"
):
if isinstance(item, RAGChunk):
response_text += item.content
print(item.content, end='', flush=True)
elif isinstance(item, ProvenanceEvent):
provenance_ids.append(item.provenance_id)
# Fetch explainability details
for prov_id in provenance_ids:
entity = explain_client.fetch_entity(
prov_id,
graph="urn:graph:retrieval",
user="trustgraph",
collection="scientists"
)
print(f"Entity: {entity}")
```
"""
request = {
"query": query,
"user": user,
"collection": collection,
"max-subgraph-size": max_subgraph_size,
"max-subgraph-count": max_subgraph_count,
"max-entity-distance": max_entity_distance,
"streaming": True,
"explainable": True, # Enable explainability mode
}
request.update(kwargs)
# Use streaming with provenance events included
return self.client._send_request_sync(
"graph-rag", self.flow_id, request,
streaming=True, include_provenance=True
)
def document_rag( def document_rag(
self, self,
query: str, query: str,
@ -654,6 +895,79 @@ class SocketFlowInstance:
else: else:
return result.get("response", "") return result.get("response", "")
def document_rag_explain(
self,
query: str,
user: str,
collection: str,
doc_limit: int = 10,
**kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""
Execute document-based RAG query with explainability support.
Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent).
Provenance events contain URIs that can be fetched using ExplainabilityClient
to get detailed information about how the response was generated.
Document RAG trace consists of:
- Question: The user's query
- Exploration: Chunks retrieved from document store (chunk_count)
- Synthesis: The generated answer
Args:
query: Natural language query
user: User/keyspace identifier
collection: Collection identifier
doc_limit: Maximum document chunks to retrieve (default: 10)
**kwargs: Additional parameters passed to the service
Yields:
Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events
Example:
```python
from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent
socket = api.socket()
flow = socket.flow("default")
explain_client = ExplainabilityClient(flow)
for item in flow.document_rag_explain(
query="Summarize the key findings",
user="trustgraph",
collection="research-papers",
doc_limit=5
):
if isinstance(item, RAGChunk):
print(item.content, end='', flush=True)
elif isinstance(item, ProvenanceEvent):
# Fetch entity details
entity = explain_client.fetch_entity(
item.explain_id,
graph=item.explain_graph,
user="trustgraph",
collection="research-papers"
)
print(f"Event: {entity}", file=sys.stderr)
```
"""
request = {
"query": query,
"user": user,
"collection": collection,
"doc-limit": doc_limit,
"streaming": True,
"explainable": True,
}
request.update(kwargs)
# Use streaming with provenance events included
return self.client._send_request_sync(
"document-rag", self.flow_id, request,
streaming=True, include_provenance=True
)
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
"""Generator for RAG streaming (graph-rag and document-rag)""" """Generator for RAG streaming (graph-rag and document-rag)"""
for chunk in result: for chunk in result:
@ -831,28 +1145,30 @@ class SocketFlowInstance:
def triples_query( def triples_query(
self, self,
s: Optional[str] = None, s: Optional[Union[str, Dict[str, Any]]] = None,
p: Optional[str] = None, p: Optional[Union[str, Dict[str, Any]]] = None,
o: Optional[str] = None, o: Optional[Union[str, Dict[str, Any]]] = None,
g: Optional[str] = None,
user: Optional[str] = None, user: Optional[str] = None,
collection: Optional[str] = None, collection: Optional[str] = None,
limit: int = 100, limit: int = 100,
**kwargs: Any **kwargs: Any
) -> Dict[str, Any]: ) -> List[Dict[str, Any]]:
""" """
Query knowledge graph triples using pattern matching. Query knowledge graph triples using pattern matching.
Args: Args:
s: Subject URI (optional, use None for wildcard) s: Subject filter - URI string, Term dict, or None for wildcard
p: Predicate URI (optional, use None for wildcard) p: Predicate filter - URI string, Term dict, or None for wildcard
o: Object URI or Literal (optional, use None for wildcard) o: Object filter - URI/literal string, Term dict, or None for wildcard
g: Named graph filter - URI string or None for all graphs
user: User/keyspace identifier (optional) user: User/keyspace identifier (optional)
collection: Collection identifier (optional) collection: Collection identifier (optional)
limit: Maximum results to return (default: 100) limit: Maximum results to return (default: 100)
**kwargs: Additional parameters passed to the service **kwargs: Additional parameters passed to the service
Returns: Returns:
dict: Query results with matching triples List[Dict]: List of matching triples in wire format
Example: Example:
```python ```python
@ -860,33 +1176,54 @@ class SocketFlowInstance:
flow = socket.flow("default") flow = socket.flow("default")
# Find all triples about a specific subject # Find all triples about a specific subject
result = flow.triples_query( triples = flow.triples_query(
s="http://example.org/person/marie-curie", s="http://example.org/person/marie-curie",
user="trustgraph", user="trustgraph",
collection="scientists" collection="scientists"
) )
# Query with named graph filter
triples = flow.triples_query(
s="urn:trustgraph:session:abc123",
g="urn:graph:retrieval",
user="trustgraph",
collection="default"
)
``` ```
""" """
request = {"limit": limit} request = {"limit": limit}
if s is not None:
request["s"] = str(s) # Build Term dicts for s/p/o (auto-converts strings)
if p is not None: s_term = build_term(s)
request["p"] = str(p) p_term = build_term(p)
if o is not None: o_term = build_term(o)
request["o"] = str(o)
if s_term is not None:
request["s"] = s_term
if p_term is not None:
request["p"] = p_term
if o_term is not None:
request["o"] = o_term
if g is not None:
request["g"] = g
if user is not None: if user is not None:
request["user"] = user request["user"] = user
if collection is not None: if collection is not None:
request["collection"] = collection request["collection"] = collection
request.update(kwargs) request.update(kwargs)
return self.client._send_request_sync("triples", self.flow_id, request, False) result = self.client._send_request_sync("triples", self.flow_id, request, False)
# Return the triples list from the response
if isinstance(result, dict) and "response" in result:
return result["response"]
return result
def triples_query_stream( def triples_query_stream(
self, self,
s: Optional[str] = None, s: Optional[Union[str, Dict[str, Any]]] = None,
p: Optional[str] = None, p: Optional[Union[str, Dict[str, Any]]] = None,
o: Optional[str] = None, o: Optional[Union[str, Dict[str, Any]]] = None,
g: Optional[str] = None,
user: Optional[str] = None, user: Optional[str] = None,
collection: Optional[str] = None, collection: Optional[str] = None,
limit: int = 100, limit: int = 100,
@ -900,9 +1237,10 @@ class SocketFlowInstance:
and memory overhead for large result sets. and memory overhead for large result sets.
Args: Args:
s: Subject URI (optional, use None for wildcard) s: Subject filter - URI string, Term dict, or None for wildcard
p: Predicate URI (optional, use None for wildcard) p: Predicate filter - URI string, Term dict, or None for wildcard
o: Object URI or Literal (optional, use None for wildcard) o: Object filter - URI/literal string, Term dict, or None for wildcard
g: Named graph filter - URI string or None for all graphs
user: User/keyspace identifier (optional) user: User/keyspace identifier (optional)
collection: Collection identifier (optional) collection: Collection identifier (optional)
limit: Maximum results to return (default: 100) limit: Maximum results to return (default: 100)
@ -930,12 +1268,20 @@ class SocketFlowInstance:
"streaming": True, "streaming": True,
"batch-size": batch_size, "batch-size": batch_size,
} }
if s is not None:
request["s"] = str(s) # Build Term dicts for s/p/o (auto-converts strings)
if p is not None: s_term = build_term(s)
request["p"] = str(p) p_term = build_term(p)
if o is not None: o_term = build_term(o)
request["o"] = str(o)
if s_term is not None:
request["s"] = s_term
if p_term is not None:
request["p"] = p_term
if o_term is not None:
request["o"] = o_term
if g is not None:
request["g"] = g
if user is not None: if user is not None:
request["user"] = user request["user"] = user
if collection is not None: if collection is not None:

View file

@ -212,19 +212,21 @@ class ProvenanceEvent:
Each event represents a provenance node created during query processing. Each event represents a provenance node created during query processing.
Attributes: Attributes:
provenance_id: URI of the provenance node (e.g., urn:trustgraph:session:abc123) explain_id: URI of the provenance node (e.g., urn:trustgraph:question:abc123)
event_type: Type of provenance event (session, retrieval, selection, answer) explain_graph: Named graph where provenance triples are stored (e.g., urn:graph:retrieval)
event_type: Type of provenance event (question, exploration, focus, synthesis)
""" """
provenance_id: str explain_id: str
event_type: str = "" # Derived from provenance_id (session, retrieval, selection, answer) explain_graph: str = ""
event_type: str = "" # Derived from explain_id
def __post_init__(self): def __post_init__(self):
# Extract event type from provenance_id # Extract event type from explain_id
if "session" in self.provenance_id: if "question" in self.explain_id:
self.event_type = "session" self.event_type = "question"
elif "retrieval" in self.provenance_id: elif "exploration" in self.explain_id:
self.event_type = "retrieval" self.event_type = "exploration"
elif "selection" in self.provenance_id: elif "focus" in self.explain_id:
self.event_type = "selection" self.event_type = "focus"
elif "answer" in self.provenance_id: elif "synthesis" in self.explain_id:
self.event_type = "answer" self.event_type = "synthesis"

View file

@ -59,6 +59,15 @@ class AgentResponseTranslator(MessageTranslator):
result["end_of_message"] = getattr(obj, "end_of_message", False) result["end_of_message"] = getattr(obj, "end_of_message", False)
result["end_of_dialog"] = getattr(obj, "end_of_dialog", False) result["end_of_dialog"] = getattr(obj, "end_of_dialog", False)
# Include explainability fields if present
explain_id = getattr(obj, "explain_id", None)
if explain_id:
result["explain_id"] = explain_id
explain_graph = getattr(obj, "explain_graph", None)
if explain_graph is not None:
result["explain_graph"] = explain_graph
# Always include error if present # Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message: if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "code": obj.error.code} result["error"] = {"message": obj.error.message, "code": obj.error.code}

View file

@ -34,7 +34,12 @@ class DocumentRagResponseTranslator(MessageTranslator):
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]: def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
result = {} result = {}
# Include response content (even if empty string) # Include message_type for distinguishing chunk vs explain messages
message_type = getattr(obj, "message_type", "")
if message_type:
result["message_type"] = message_type
# Include response content for chunk messages
if obj.response is not None: if obj.response is not None:
result["response"] = obj.response result["response"] = obj.response
@ -48,9 +53,12 @@ class DocumentRagResponseTranslator(MessageTranslator):
if explain_graph is not None: if explain_graph is not None:
result["explain_graph"] = explain_graph result["explain_graph"] = explain_graph
# Include end_of_stream flag # Include end_of_stream flag (LLM stream complete)
result["end_of_stream"] = getattr(obj, "end_of_stream", False) result["end_of_stream"] = getattr(obj, "end_of_stream", False)
# Include end_of_session flag (entire session complete)
result["end_of_session"] = getattr(obj, "end_of_session", False)
# Always include error if present # Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message: if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type} result["error"] = {"message": obj.error.message, "type": obj.error.type}
@ -59,7 +67,8 @@ class DocumentRagResponseTranslator(MessageTranslator):
def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]: def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)""" """Returns (response_dict, is_final)"""
is_final = getattr(obj, 'end_of_stream', False) # Session is complete when end_of_session is True
is_final = getattr(obj, 'end_of_session', False)
return self.from_pulsar(obj), is_final return self.from_pulsar(obj), is_final

View file

@ -82,6 +82,10 @@ from . namespaces import (
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
# Agent provenance predicates # Agent provenance predicates
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
# Agent document references
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
# Document reference predicate
TG_DOCUMENT,
# Named graphs # Named graphs
GRAPH_DEFAULT, GRAPH_SOURCE, GRAPH_RETRIEVAL, GRAPH_DEFAULT, GRAPH_SOURCE, GRAPH_RETRIEVAL,
) )
@ -165,6 +169,10 @@ __all__ = [
"TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION", "TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION",
# Agent provenance predicates # Agent provenance predicates
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_ANSWER", "TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_ANSWER",
# Agent document references
"TG_THOUGHT_DOCUMENT", "TG_OBSERVATION_DOCUMENT",
# Document reference predicate
"TG_DOCUMENT",
# Named graphs # Named graphs
"GRAPH_DEFAULT", "GRAPH_SOURCE", "GRAPH_RETRIEVAL", "GRAPH_DEFAULT", "GRAPH_SOURCE", "GRAPH_RETRIEVAL",
# Triple builders # Triple builders

View file

@ -17,7 +17,8 @@ from . namespaces import (
RDF_TYPE, RDFS_LABEL, RDF_TYPE, RDFS_LABEL,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME, PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
TG_AGENT_QUESTION, TG_AGENT_QUESTION,
) )
@ -73,10 +74,12 @@ def agent_session_triples(
def agent_iteration_triples( def agent_iteration_triples(
iteration_uri: str, iteration_uri: str,
parent_uri: str, parent_uri: str,
thought: str, thought: str = "",
action: str, action: str = "",
arguments: Dict[str, Any], arguments: Dict[str, Any] = None,
observation: str, observation: str = "",
thought_document_id: Optional[str] = None,
observation_document_id: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for one agent iteration (Analysis - think/act/observe cycle). Build triples for one agent iteration (Analysis - think/act/observe cycle).
@ -85,36 +88,53 @@ def agent_iteration_triples(
- Entity declaration with tg:Analysis type - Entity declaration with tg:Analysis type
- wasDerivedFrom link to parent (previous iteration or session) - wasDerivedFrom link to parent (previous iteration or session)
- Thought, action, arguments, and observation data - Thought, action, arguments, and observation data
- Document references for thought/observation when stored in librarian
Args: Args:
iteration_uri: URI of this iteration (from agent_iteration_uri) iteration_uri: URI of this iteration (from agent_iteration_uri)
parent_uri: URI of the parent (previous iteration or session) parent_uri: URI of the parent (previous iteration or session)
thought: The agent's reasoning/thought thought: The agent's reasoning/thought (used if thought_document_id not provided)
action: The tool/action name action: The tool/action name
arguments: Arguments passed to the tool (will be JSON-encoded) arguments: Arguments passed to the tool (will be JSON-encoded)
observation: The result/observation from the tool observation: The result/observation from the tool (used if observation_document_id not provided)
thought_document_id: Optional document URI for thought in librarian (preferred)
observation_document_id: Optional document URI for observation in librarian (preferred)
Returns: Returns:
List of Triple objects List of Triple objects
""" """
if arguments is None:
arguments = {}
triples = [ triples = [
_triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)), _triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)),
_triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")), _triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")),
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)), _triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
_triple(iteration_uri, TG_THOUGHT, _literal(thought)),
_triple(iteration_uri, TG_ACTION, _literal(action)), _triple(iteration_uri, TG_ACTION, _literal(action)),
_triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))), _triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))),
_triple(iteration_uri, TG_OBSERVATION, _literal(observation)),
] ]
# Thought: use document reference or inline
if thought_document_id:
triples.append(_triple(iteration_uri, TG_THOUGHT_DOCUMENT, _iri(thought_document_id)))
elif thought:
triples.append(_triple(iteration_uri, TG_THOUGHT, _literal(thought)))
# Observation: use document reference or inline
if observation_document_id:
triples.append(_triple(iteration_uri, TG_OBSERVATION_DOCUMENT, _iri(observation_document_id)))
elif observation:
triples.append(_triple(iteration_uri, TG_OBSERVATION, _literal(observation)))
return triples return triples
def agent_final_triples( def agent_final_triples(
final_uri: str, final_uri: str,
parent_uri: str, parent_uri: str,
answer: str, answer: str = "",
document_id: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for an agent final answer (Conclusion). Build triples for an agent final answer (Conclusion).
@ -122,20 +142,29 @@ def agent_final_triples(
Creates: Creates:
- Entity declaration with tg:Conclusion type - Entity declaration with tg:Conclusion type
- wasDerivedFrom link to parent (last iteration or session) - wasDerivedFrom link to parent (last iteration or session)
- The answer text - Either document reference (if document_id provided) or inline answer
Args: Args:
final_uri: URI of the final answer (from agent_final_uri) final_uri: URI of the final answer (from agent_final_uri)
parent_uri: URI of the parent (last iteration or session if no iterations) parent_uri: URI of the parent (last iteration or session if no iterations)
answer: The final answer text answer: The final answer text (used if document_id not provided)
document_id: Optional document URI in librarian (preferred)
Returns: Returns:
List of Triple objects List of Triple objects
""" """
return [ triples = [
_triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)), _triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)),
_triple(final_uri, RDFS_LABEL, _literal("Conclusion")), _triple(final_uri, RDFS_LABEL, _literal("Conclusion")),
_triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)), _triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
_triple(final_uri, TG_ANSWER, _literal(answer)),
] ]
if document_id:
# Store reference to document in librarian (as IRI)
triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id)))
elif answer:
# Fallback: store inline answer
triples.append(_triple(final_uri, TG_ANSWER, _literal(answer)))
return triples

View file

@ -92,6 +92,10 @@ TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation" TG_OBSERVATION = TG + "observation"
TG_ANSWER = TG + "answer" TG_ANSWER = TG + "answer"
# Agent document references (for librarian storage)
TG_THOUGHT_DOCUMENT = TG + "thoughtDocument"
TG_OBSERVATION_DOCUMENT = TG + "observationDocument"
# Named graph URIs for RDF datasets # Named graph URIs for RDF datasets
# These separate different types of data while keeping them in the same collection # These separate different types of data while keeping them in the same collection
GRAPH_DEFAULT = "" # Core knowledge facts (triples extracted from documents) GRAPH_DEFAULT = "" # Core knowledge facts (triples extracted from documents)

View file

@ -30,11 +30,15 @@ class AgentRequest:
@dataclass @dataclass
class AgentResponse: class AgentResponse:
# Streaming-first design # Streaming-first design
chunk_type: str = "" # "thought", "action", "observation", "answer", "error" chunk_type: str = "" # "thought", "action", "observation", "answer", "explain", "error"
content: str = "" # The actual content (interpretation depends on chunk_type) content: str = "" # The actual content (interpretation depends on chunk_type)
end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete
end_of_dialog: bool = False # Entire agent dialog is complete end_of_dialog: bool = False # Entire agent dialog is complete
# Explainability fields
explain_id: str | None = None # Provenance URI (announced as created)
explain_graph: str | None = None # Named graph where explain was stored
# Legacy fields (deprecated but kept for backward compatibility) # Legacy fields (deprecated but kept for backward compatibility)
answer: str = "" answer: str = ""
error: Error | None = None error: Error | None = None

View file

@ -43,6 +43,8 @@ class DocumentRagQuery:
class DocumentRagResponse: class DocumentRagResponse:
error: Error | None = None error: Error | None = None
response: str | None = "" response: str | None = ""
end_of_stream: bool = False end_of_stream: bool = False # LLM response stream complete
explain_id: str | None = None # Single explain URI (announced as created) explain_id: str | None = None # Single explain URI (announced as created)
explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval) explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval)
message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete

View file

@ -4,8 +4,19 @@ Uses the agent service to answer a question
import argparse import argparse
import os import os
import sys
import textwrap import textwrap
from trustgraph.api import Api from trustgraph.api import (
Api,
ExplainabilityClient,
ProvenanceEvent,
Question,
Analysis,
Conclusion,
AgentThought,
AgentObservation,
AgentAnswer,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -97,11 +108,148 @@ def output(text, prefix="> ", width=78):
) )
print(out) print(out)
def question_explainable(
url, question_text, flow_id, user, collection,
state=None, group=None, verbose=False, token=None, debug=False
):
"""Execute agent with explainability - shows provenance events inline."""
api = Api(url=url, token=token)
socket = api.socket()
flow = socket.flow(flow_id)
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
try:
# Track last chunk type for formatting
last_chunk_type = None
current_outputter = None
# Stream agent with explainability - process events as they arrive
for item in flow.agent_explain(
question=question_text,
user=user,
collection=collection,
state=state,
group=group,
):
if isinstance(item, AgentThought):
if last_chunk_type != "thought":
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
print() # Blank line between message types
if verbose:
current_outputter = Outputter(width=78, prefix="\U0001f914 ")
current_outputter.__enter__()
last_chunk_type = "thought"
if current_outputter:
current_outputter.output(item.content)
if current_outputter.word_buffer:
print(current_outputter.word_buffer, end="", flush=True)
current_outputter.column += len(current_outputter.word_buffer)
current_outputter.word_buffer = ""
elif isinstance(item, AgentObservation):
if last_chunk_type != "observation":
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
print()
if verbose:
current_outputter = Outputter(width=78, prefix="\U0001f4a1 ")
current_outputter.__enter__()
last_chunk_type = "observation"
if current_outputter:
current_outputter.output(item.content)
if current_outputter.word_buffer:
print(current_outputter.word_buffer, end="", flush=True)
current_outputter.column += len(current_outputter.word_buffer)
current_outputter.word_buffer = ""
elif isinstance(item, AgentAnswer):
if last_chunk_type != "answer":
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
print()
last_chunk_type = "answer"
# Print answer content directly
print(item.content, end="", flush=True)
elif isinstance(item, ProvenanceEvent):
# Process provenance event immediately
prov_id = item.explain_id
explain_graph = item.explain_graph or "urn:graph:retrieval"
entity = explain_client.fetch_entity(
prov_id,
graph=explain_graph,
user=user,
collection=collection
)
if entity is None:
if debug:
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
continue
# Display based on entity type
if isinstance(entity, Question):
print(f"\n [session] {prov_id}", file=sys.stderr)
if entity.query:
print(f" Query: {entity.query}", file=sys.stderr)
if entity.timestamp:
print(f" Time: {entity.timestamp}", file=sys.stderr)
elif isinstance(entity, Analysis):
print(f"\n [iteration] {prov_id}", file=sys.stderr)
if entity.thought:
thought_short = entity.thought[:80] + "..." if len(entity.thought) > 80 else entity.thought
print(f" Thought: {thought_short}", file=sys.stderr)
if entity.action:
print(f" Action: {entity.action}", file=sys.stderr)
elif isinstance(entity, Conclusion):
print(f"\n [conclusion] {prov_id}", file=sys.stderr)
if entity.answer:
print(f" Answer length: {len(entity.answer)} chars", file=sys.stderr)
else:
if debug:
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
# Close any remaining outputter
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
# Final newline if we ended with answer
if last_chunk_type == "answer":
print()
finally:
socket.close()
def question( def question(
url, question, flow_id, user, collection, url, question, flow_id, user, collection,
plan=None, state=None, group=None, verbose=False, streaming=True, plan=None, state=None, group=None, verbose=False, streaming=True,
token=None token=None, explainable=False, debug=False
): ):
# Explainable mode uses the API to capture and process provenance events
if explainable:
question_explainable(
url=url,
question_text=question,
flow_id=flow_id,
user=user,
collection=collection,
state=state,
group=group,
verbose=verbose,
token=token,
debug=debug
)
return
if verbose: if verbose:
output(wrap(question), "\U00002753 ") output(wrap(question), "\U00002753 ")
@ -270,6 +418,18 @@ def main():
help=f'Disable streaming (use legacy mode)' help=f'Disable streaming (use legacy mode)'
) )
parser.add_argument(
'-x', '--explainable',
action='store_true',
help='Show provenance events: Session, Iterations, Conclusion (implies streaming)'
)
parser.add_argument(
'--debug',
action='store_true',
help='Show debug output for troubleshooting'
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -286,6 +446,8 @@ def main():
verbose = args.verbose, verbose = args.verbose,
streaming = not args.no_streaming, streaming = not args.no_streaming,
token = args.token, token = args.token,
explainable = args.explainable,
debug = args.debug,
) )
except Exception as e: except Exception as e:

View file

@ -4,7 +4,16 @@ Uses the DocumentRAG service to answer a question
import argparse import argparse
import os import os
from trustgraph.api import Api import sys
from trustgraph.api import (
Api,
ExplainabilityClient,
RAGChunk,
ProvenanceEvent,
Question,
Exploration,
Synthesis,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -12,7 +21,90 @@ default_user = 'trustgraph'
default_collection = 'default' default_collection = 'default'
default_doc_limit = 10 default_doc_limit = 10
def question(url, flow_id, question, user, collection, doc_limit, streaming=True, token=None):
def question_explainable(
url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False
):
"""Execute document RAG with explainability - shows provenance events inline."""
api = Api(url=url, token=token)
socket = api.socket()
flow = socket.flow(flow_id)
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
try:
# Stream DocumentRAG with explainability - process events as they arrive
for item in flow.document_rag_explain(
query=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
):
if isinstance(item, RAGChunk):
# Print response content
print(item.content, end="", flush=True)
elif isinstance(item, ProvenanceEvent):
# Process provenance event immediately
prov_id = item.explain_id
explain_graph = item.explain_graph or "urn:graph:retrieval"
entity = explain_client.fetch_entity(
prov_id,
graph=explain_graph,
user=user,
collection=collection
)
if entity is None:
if debug:
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
continue
# Display based on entity type
if isinstance(entity, Question):
print(f"\n [question] {prov_id}", file=sys.stderr)
if entity.query:
print(f" Query: {entity.query}", file=sys.stderr)
if entity.timestamp:
print(f" Time: {entity.timestamp}", file=sys.stderr)
elif isinstance(entity, Exploration):
print(f"\n [exploration] {prov_id}", file=sys.stderr)
if entity.chunk_count:
print(f" Chunks retrieved: {entity.chunk_count}", file=sys.stderr)
elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
if entity.content:
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
else:
if debug:
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
print() # Final newline
finally:
socket.close()
def question(
url, flow_id, question_text, user, collection, doc_limit,
streaming=True, token=None, explainable=False, debug=False
):
# Explainable mode uses the API to capture and process provenance events
if explainable:
question_explainable(
url=url,
flow_id=flow_id,
question_text=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
token=token,
debug=debug
)
return
# Create API client # Create API client
api = Api(url=url, token=token) api = Api(url=url, token=token)
@ -24,7 +116,7 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True
try: try:
response = flow.document_rag( response = flow.document_rag(
query=question, query=question_text,
user=user, user=user,
collection=collection, collection=collection,
doc_limit=doc_limit, doc_limit=doc_limit,
@ -42,13 +134,14 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True
# Use REST API for non-streaming # Use REST API for non-streaming
flow = api.flow().id(flow_id) flow = api.flow().id(flow_id)
resp = flow.document_rag( resp = flow.document_rag(
query=question, query=question_text,
user=user, user=user,
collection=collection, collection=collection,
doc_limit=doc_limit, doc_limit=doc_limit,
) )
print(resp) print(resp)
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -105,6 +198,18 @@ def main():
help='Disable streaming (use non-streaming mode)' help='Disable streaming (use non-streaming mode)'
) )
parser.add_argument(
'-x', '--explainable',
action='store_true',
help='Show provenance events: Question, Exploration, Synthesis (implies streaming)'
)
parser.add_argument(
'--debug',
action='store_true',
help='Show debug output for troubleshooting'
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -112,12 +217,14 @@ def main():
question( question(
url=args.url, url=args.url,
flow_id=args.flow_id, flow_id=args.flow_id,
question=args.question, question_text=args.question,
user=args.user, user=args.user,
collection=args.collection, collection=args.collection,
doc_limit=args.doc_limit, doc_limit=args.doc_limit,
streaming=not args.no_streaming, streaming=not args.no_streaming,
token=args.token, token=args.token,
explainable=args.explainable,
debug=args.debug,
) )
except Exception as e: except Exception as e:

View file

@ -8,7 +8,16 @@ import os
import sys import sys
import websockets import websockets
import asyncio import asyncio
from trustgraph.api import Api from trustgraph.api import (
Api,
ExplainabilityClient,
RAGChunk,
ProvenanceEvent,
Question,
Exploration,
Focus,
Synthesis,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -602,18 +611,111 @@ async def _question_explainable(
print() # Final newline print() # Final newline
def _question_explainable_api(
url, flow_id, question_text, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, token=None, debug=False
):
"""Execute graph RAG with explainability using the new API classes."""
api = Api(url=url, token=token)
socket = api.socket()
flow = socket.flow(flow_id)
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
try:
# Stream GraphRAG with explainability - process events as they arrive
for item in flow.graph_rag_explain(
query=question_text,
user=user,
collection=collection,
max_subgraph_size=max_subgraph_size,
max_subgraph_count=5,
max_entity_distance=max_path_length,
):
if isinstance(item, RAGChunk):
# Print response content
print(item.content, end="", flush=True)
elif isinstance(item, ProvenanceEvent):
# Process provenance event immediately
prov_id = item.explain_id
explain_graph = item.explain_graph or "urn:graph:retrieval"
entity = explain_client.fetch_entity(
prov_id,
graph=explain_graph,
user=user,
collection=collection
)
if entity is None:
if debug:
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
continue
# Display based on entity type
if isinstance(entity, Question):
print(f"\n [question] {prov_id}", file=sys.stderr)
if entity.query:
print(f" Query: {entity.query}", file=sys.stderr)
if entity.timestamp:
print(f" Time: {entity.timestamp}", file=sys.stderr)
elif isinstance(entity, Exploration):
print(f"\n [exploration] {prov_id}", file=sys.stderr)
if entity.edge_count:
print(f" Edges explored: {entity.edge_count}", file=sys.stderr)
elif isinstance(entity, Focus):
print(f"\n [focus] {prov_id}", file=sys.stderr)
if entity.selected_edge_uris:
print(f" Focused on {len(entity.selected_edge_uris)} edge(s)", file=sys.stderr)
# Fetch full focus with edge details
focus_full = explain_client.fetch_focus_with_edges(
prov_id,
graph=explain_graph,
user=user,
collection=collection
)
if focus_full and focus_full.edge_selections:
for edge_sel in focus_full.edge_selections:
if edge_sel.edge:
# Resolve labels for edge components
s_label, p_label, o_label = explain_client.resolve_edge_labels(
edge_sel.edge, user, collection
)
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
if edge_sel.reasoning:
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reason: {r_short}", file=sys.stderr)
elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
if entity.content:
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
else:
if debug:
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
print() # Final newline
finally:
socket.close()
def question( def question(
url, flow_id, question, user, collection, entity_limit, triple_limit, url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, streaming=True, token=None, max_subgraph_size, max_path_length, streaming=True, token=None,
explainable=False, debug=False explainable=False, debug=False
): ):
# Explainable mode uses direct websocket to capture provenance events # Explainable mode uses the API to capture and process provenance events
if explainable: if explainable:
asyncio.run(_question_explainable( _question_explainable_api(
url=url, url=url,
flow_id=flow_id, flow_id=flow_id,
question=question, question_text=question,
user=user, user=user,
collection=collection, collection=collection,
entity_limit=entity_limit, entity_limit=entity_limit,
@ -622,7 +724,7 @@ def question(
max_path_length=max_path_length, max_path_length=max_path_length,
token=token, token=token,
debug=debug debug=debug
)) )
return return
# Create API client # Create API client

View file

@ -14,180 +14,17 @@ import json
import os import os
import sys import sys
from tabulate import tabulate from tabulate import tabulate
from trustgraph.api import Api from trustgraph.api import Api, ExplainabilityClient
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_user = 'trustgraph' default_user = 'trustgraph'
default_collection = 'default' default_collection = 'default'
# Predicates
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_QUESTION = TG + "Question"
TG_ANALYSIS = TG + "Analysis"
TG_EXPLORATION = TG + "Exploration"
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
# Retrieval graph # Retrieval graph
RETRIEVAL_GRAPH = "urn:graph:retrieval" RETRIEVAL_GRAPH = "urn:graph:retrieval"
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
"""Query triples using the socket API."""
request = {
"user": user,
"collection": collection,
"limit": limit,
"streaming": False,
}
if s is not None:
request["s"] = {"t": "i", "i": s}
if p is not None:
request["p"] = {"t": "i", "i": p}
if o is not None:
if isinstance(o, str):
if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"):
request["o"] = {"t": "i", "i": o}
else:
request["o"] = {"t": "l", "v": o}
elif isinstance(o, dict):
request["o"] = o
if g is not None:
request["g"] = g
triples = []
try:
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
if isinstance(response, dict):
triple_list = response.get("response", response.get("triples", []))
else:
triple_list = response
if not isinstance(triple_list, list):
triple_list = [triple_list] if triple_list else []
for t in triple_list:
s_val = extract_value(t.get("s", {}))
p_val = extract_value(t.get("p", {}))
o_val = extract_value(t.get("o", {}))
triples.append((s_val, p_val, o_val))
except Exception as e:
print(f"Error querying triples: {e}", file=sys.stderr)
return triples
def extract_value(term):
"""Extract value from a term dict."""
if not term:
return ""
t = term.get("t") or term.get("type")
if t == "i":
return term.get("i") or term.get("iri", "")
elif t == "l":
return term.get("v") or term.get("value", "")
elif t == "t":
# Quoted triple
tr = term.get("tr") or term.get("triple", {})
return {
"s": extract_value(tr.get("s", {})),
"p": extract_value(tr.get("p", {})),
"o": extract_value(tr.get("o", {})),
}
# Fallback for raw values
if "i" in term:
return term["i"]
if "v" in term:
return term["v"]
return str(term)
def get_timestamp(socket, flow_id, user, collection, question_id):
"""Get timestamp for a question."""
triples = query_triples(
socket, flow_id, user, collection,
s=question_id, p=PROV_STARTED_AT_TIME, g=RETRIEVAL_GRAPH
)
for s, p, o in triples:
return o
return ""
def get_session_type(socket, flow_id, user, collection, session_id):
"""
Get the type of session (Agent or GraphRAG).
Both have tg:Question type, so we distinguish by URI pattern
or by checking what's derived from it.
"""
# Fast path: check URI pattern
if session_id.startswith("urn:trustgraph:agent:"):
return "Agent"
if session_id.startswith("urn:trustgraph:question:"):
return "GraphRAG"
# Check what's derived from this entity
derived = query_triples(
socket, flow_id, user, collection,
p=PROV_WAS_DERIVED_FROM, o=session_id, g=RETRIEVAL_GRAPH
)
generated = query_triples(
socket, flow_id, user, collection,
p=PROV_WAS_GENERATED_BY, o=session_id, g=RETRIEVAL_GRAPH
)
for s, p, o in derived + generated:
child_types = query_triples(
socket, flow_id, user, collection,
s=s, p=RDF_TYPE, g=RETRIEVAL_GRAPH
)
for _, _, child_type in child_types:
if child_type == TG_ANALYSIS:
return "Agent"
if child_type == TG_EXPLORATION:
return "GraphRAG"
return "GraphRAG"
def list_sessions(socket, flow_id, user, collection, limit):
"""List all explainability sessions (GraphRAG and Agent) by finding questions."""
# Query for all triples with predicate = tg:query
triples = query_triples(
socket, flow_id, user, collection,
p=TG_QUERY, g=RETRIEVAL_GRAPH, limit=limit
)
sessions = []
for question_id, _, query_text in triples:
# Get timestamp if available
timestamp = get_timestamp(socket, flow_id, user, collection, question_id)
# Get session type (Agent or GraphRAG)
session_type = get_session_type(socket, flow_id, user, collection, question_id)
sessions.append({
"id": question_id,
"type": session_type,
"question": query_text,
"time": timestamp,
})
# Sort by timestamp (newest first) if available
sessions.sort(key=lambda x: x.get("time", ""), reverse=True)
return sessions
def truncate_text(text, max_len=60): def truncate_text(text, max_len=60):
"""Truncate text to max length with ellipsis.""" """Truncate text to max length with ellipsis."""
if not text: if not text:
@ -277,16 +114,42 @@ def main():
try: try:
api = Api(args.api_url, token=args.token) api = Api(args.api_url, token=args.token)
socket = api.socket() socket = api.socket()
flow = socket.flow(args.flow_id)
explain_client = ExplainabilityClient(flow)
try: try:
sessions = list_sessions( # List all sessions using the API
socket=socket, questions = explain_client.list_sessions(
flow_id=args.flow_id, graph=RETRIEVAL_GRAPH,
user=args.user, user=args.user,
collection=args.collection, collection=args.collection,
limit=args.limit, limit=args.limit,
) )
# Convert to output format
sessions = []
for q in questions:
session_type = explain_client.detect_session_type(
q.uri,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection
)
# Map type names
type_display = {
"graphrag": "GraphRAG",
"docrag": "DocRAG",
"agent": "Agent",
}.get(session_type, session_type.title())
sessions.append({
"id": q.uri,
"type": type_display,
"question": q.query,
"time": q.timestamp,
})
if args.format == 'json': if args.format == 'json':
print_json(sessions) print_json(sessions)
else: else:

View file

@ -291,42 +291,25 @@ def query_graph(
): ):
"""Query the triple store with pattern matching. """Query the triple store with pattern matching.
Uses the WebSocket API's raw streaming mode for efficient delivery of results. Uses the API's triples_query_stream for efficient streaming delivery.
""" """
socket = Api(url, token=token).socket() socket = Api(url, token=token).socket()
flow = socket.flow(flow_id)
# Build request dict directly (bypassing triples_query_stream's string conversion)
request = {
"user": user,
"collection": collection,
"limit": limit,
"streaming": True,
"batch-size": batch_size,
}
# Add term dicts for s/p/o (None means wildcard)
if subject is not None:
request["s"] = subject
if predicate is not None:
request["p"] = predicate
if obj is not None:
request["o"] = obj
if graph is not None:
request["g"] = graph
all_triples = [] all_triples = []
try: try:
# Use raw streaming mode - yields response dicts directly # Use triples_query_stream - accepts Term dicts directly
for response in socket._send_request_sync( for triples in flow.triples_query_stream(
"triples", flow_id, request, streaming_raw=True s=subject,
p=predicate,
o=obj,
g=graph,
user=user,
collection=collection,
limit=limit,
batch_size=batch_size,
): ):
# Response may have triples in different locations depending on format
if isinstance(response, dict):
triples = response.get("response", response.get("triples", []))
else:
triples = response
if not isinstance(triples, list): if not isinstance(triples, list):
triples = [triples] if triples else [] triples = [triples] if triples else []

View file

@ -18,228 +18,99 @@ import argparse
import json import json
import os import os
import sys import sys
from trustgraph.api import Api from trustgraph.api import (
Api,
ExplainabilityClient,
Question,
Exploration,
Focus,
Synthesis,
Analysis,
Conclusion,
)
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_user = 'trustgraph' default_user = 'trustgraph'
default_collection = 'default' default_collection = 'default'
# Predicates
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_DOCUMENT = TG + "document"
TG_REIFIES = TG + "reifies"
# Explainability entity types
TG_QUESTION = TG + "Question"
TG_EXPLORATION = TG + "Exploration"
TG_FOCUS = TG + "Focus"
TG_SYNTHESIS = TG + "Synthesis"
TG_ANALYSIS = TG + "Analysis"
TG_CONCLUSION = TG + "Conclusion"
# Agent predicates
TG_THOUGHT = TG + "thought"
TG_ACTION = TG + "action"
TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation"
TG_ANSWER = TG + "answer"
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
# Graphs # Graphs
RETRIEVAL_GRAPH = "urn:graph:retrieval" RETRIEVAL_GRAPH = "urn:graph:retrieval"
SOURCE_GRAPH = "urn:graph:source" SOURCE_GRAPH = "urn:graph:source"
# Provenance predicates for edge tracing
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000): TG = "https://trustgraph.ai/ns/"
"""Query triples using the socket API.""" TG_REIFIES = TG + "reifies"
request = { PROV = "http://www.w3.org/ns/prov#"
"user": user, PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
"collection": collection,
"limit": limit,
"streaming": False,
}
if s is not None:
request["s"] = {"t": "i", "i": s}
if p is not None:
request["p"] = {"t": "i", "i": p}
if o is not None:
if isinstance(o, str):
if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"):
request["o"] = {"t": "i", "i": o}
else:
request["o"] = {"t": "l", "v": o}
elif isinstance(o, dict):
request["o"] = o
if g is not None:
request["g"] = g
triples = []
try:
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
if isinstance(response, dict):
triple_list = response.get("response", response.get("triples", []))
else:
triple_list = response
if not isinstance(triple_list, list):
triple_list = [triple_list] if triple_list else []
for t in triple_list:
s_val = extract_value(t.get("s", {}))
p_val = extract_value(t.get("p", {}))
o_val = extract_value(t.get("o", {}))
triples.append((s_val, p_val, o_val))
except Exception as e:
print(f"Error querying triples: {e}", file=sys.stderr)
return triples
def extract_value(term): def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_client):
"""Extract value from a term dict.""" """
if not term: Trace an edge back to its source document via reification.
return ""
t = term.get("t") or term.get("type") Args:
flow: SocketFlowInstance
user: User identifier
collection: Collection identifier
edge: Dict with s, p, o keys
label_cache: Dict for caching labels
explain_client: ExplainabilityClient for label resolution
if t == "i": Returns:
return term.get("i") or term.get("iri", "") List of provenance chains, each chain is list of {uri, label}
elif t == "l": """
return term.get("v") or term.get("value", "") edge_s = edge.get("s", "")
elif t == "t": edge_p = edge.get("p", "")
# Quoted triple edge_o = edge.get("o", "")
tr = term.get("tr") or term.get("triple", {})
return {
"s": extract_value(tr.get("s", {})),
"p": extract_value(tr.get("p", {})),
"o": extract_value(tr.get("o", {})),
}
# Fallback for raw values # Build quoted triple for lookup
if "i" in term: def build_term(val):
return term["i"] if isinstance(val, str) and (val.startswith("http") or val.startswith("urn:")):
if "v" in term: return {"t": "i", "i": val}
return term["v"] return {"t": "l", "v": str(val)}
return str(term)
def get_node_properties(socket, flow_id, user, collection, node_uri, graph=RETRIEVAL_GRAPH):
"""Get all properties of a node as a dict."""
triples = query_triples(socket, flow_id, user, collection, s=node_uri, g=graph)
props = {}
for s, p, o in triples:
if p not in props:
props[p] = []
props[p].append(o)
return props
def find_by_predicate_object(socket, flow_id, user, collection, predicate, obj, graph=RETRIEVAL_GRAPH):
"""Find subjects where predicate = obj."""
triples = query_triples(socket, flow_id, user, collection, p=predicate, o=obj, g=graph)
return [s for s, p, o in triples]
def get_label(socket, flow_id, user, collection, uri, label_cache):
"""Get label for a URI, with caching."""
if not isinstance(uri, str) or not (uri.startswith("http://") or uri.startswith("https://") or uri.startswith("urn:")):
return uri
if uri in label_cache:
return label_cache[uri]
triples = query_triples(socket, flow_id, user, collection, s=uri, p=RDFS_LABEL)
for s, p, o in triples:
label_cache[uri] = o
return o
label_cache[uri] = uri
return uri
def get_document_content(api, user, doc_id, max_content):
"""Fetch document content from librarian API."""
try:
library = api.library()
content = library.get_document_content(user=user, id=doc_id)
# Try to decode as text
try:
text = content.decode('utf-8')
if len(text) > max_content:
return text[:max_content] + "... [truncated]"
return text
except UnicodeDecodeError:
return f"[Binary: {len(content)} bytes]"
except Exception as e:
return f"[Error fetching content: {e}]"
def trace_edge_provenance(socket, flow_id, user, collection, edge_s, edge_p, edge_o, label_cache):
"""Trace an edge back to its source document via reification."""
# Build the quoted triple for lookup
quoted_triple = { quoted_triple = {
"t": "t", "t": "t",
"tr": { "tr": {
"s": {"t": "i", "i": edge_s} if isinstance(edge_s, str) and (edge_s.startswith("http") or edge_s.startswith("urn:")) else {"t": "l", "v": edge_s}, "s": build_term(edge_s),
"p": {"t": "i", "i": edge_p}, "p": build_term(edge_p),
"o": {"t": "i", "i": edge_o} if isinstance(edge_o, str) and (edge_o.startswith("http") or edge_o.startswith("urn:")) else {"t": "l", "v": edge_o}, "o": build_term(edge_o),
} }
} }
# Query: ?stmt tg:reifies <<edge>> # Query: ?stmt tg:reifies <<edge>>
request = {
"user": user,
"collection": collection,
"limit": 10,
"streaming": False,
"p": {"t": "i", "i": TG_REIFIES},
"o": quoted_triple,
"g": SOURCE_GRAPH,
}
stmt_uris = []
try: try:
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True): results = flow.triples_query(
if isinstance(response, dict): p=TG_REIFIES,
triple_list = response.get("response", response.get("triples", [])) o=quoted_triple,
else: g=SOURCE_GRAPH,
triple_list = response user=user,
collection=collection,
if not isinstance(triple_list, list): limit=10
triple_list = [triple_list] if triple_list else [] )
for t in triple_list:
s_val = extract_value(t.get("s", {}))
if s_val:
stmt_uris.append(s_val)
except Exception: except Exception:
pass return []
# For each statement, find wasDerivedFrom chain # Extract statement URIs
stmt_uris = []
for t in results:
s_term = t.get("s", {})
s_val = s_term.get("i") or s_term.get("v", "")
if s_val:
stmt_uris.append(s_val)
# For each statement, trace wasDerivedFrom chain
provenance_chains = [] provenance_chains = []
for stmt_uri in stmt_uris: for stmt_uri in stmt_uris:
chain = trace_provenance_chain(socket, flow_id, user, collection, stmt_uri, label_cache) chain = trace_provenance_chain(flow, user, collection, stmt_uri, label_cache, explain_client)
if chain: if chain:
provenance_chains.append(chain) provenance_chains.append(chain)
return provenance_chains return provenance_chains
def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_cache, max_depth=10): def trace_provenance_chain(flow, user, collection, start_uri, label_cache, explain_client, max_depth=10):
"""Trace prov:wasDerivedFrom chain from start_uri to root.""" """Trace prov:wasDerivedFrom chain from start_uri to root."""
chain = [] chain = []
current = start_uri current = start_uri
@ -248,17 +119,32 @@ def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_c
if not current: if not current:
break break
label = get_label(socket, flow_id, user, collection, current, label_cache) # Get label
if current in label_cache:
label = label_cache[current]
else:
label = explain_client.resolve_label(current, user, collection)
label_cache[current] = label
chain.append({"uri": current, "label": label}) chain.append({"uri": current, "label": label})
# Get parent # Get parent via wasDerivedFrom
triples = query_triples( try:
socket, flow_id, user, collection, results = flow.triples_query(
s=current, p=PROV_WAS_DERIVED_FROM, g=SOURCE_GRAPH s=current,
) p=PROV_WAS_DERIVED_FROM,
g=SOURCE_GRAPH,
user=user,
collection=collection,
limit=1
)
except Exception:
break
parent = None parent = None
for s, p, o in triples: for t in results:
parent = o o_term = t.get("o", {})
parent = o_term.get("i") or o_term.get("v", "")
break break
if not parent or parent == current: if not parent or parent == current:
@ -276,331 +162,24 @@ def format_provenance_chain(chain):
return " -> ".join(labels) return " -> ".join(labels)
def format_edge(edge, label_cache=None, socket=None, flow_id=None, user=None, collection=None): def print_graphrag_text(trace, explain_client, flow, user, collection, show_provenance=False):
"""Format a quoted triple edge for display.""" """Print GraphRAG trace in text format."""
if not isinstance(edge, dict): question = trace.get("question")
return str(edge)
s = edge.get("s", "?") print(f"=== GraphRAG Session: {question.uri if question else 'Unknown'} ===")
p = edge.get("p", "?")
o = edge.get("o", "?")
# Get labels if available
if label_cache and socket:
s_label = get_label(socket, flow_id, user, collection, s, label_cache)
p_label = get_label(socket, flow_id, user, collection, p, label_cache)
o_label = get_label(socket, flow_id, user, collection, o, label_cache)
else:
# Shorten URIs for display
s_label = s.split("/")[-1] if "/" in str(s) else s
p_label = p.split("/")[-1] if "/" in str(p) else p
o_label = o.split("/")[-1] if "/" in str(o) else o
return f"({s_label}, {p_label}, {o_label})"
def detect_trace_type(socket, flow_id, user, collection, entity_id):
"""
Detect whether an entity is an agent Question or GraphRAG Question.
Both have rdf:type = tg:Question, so we distinguish by checking
what's derived from it:
- Agent: has tg:Analysis or tg:Conclusion derived
- GraphRAG: has tg:Exploration derived
Also checks URI pattern as fallback:
- urn:trustgraph:agent: -> agent
- urn:trustgraph:question: -> graphrag
Returns:
"agent" or "graphrag"
"""
# Check URI pattern first (fast path)
if entity_id.startswith("urn:trustgraph:agent:"):
return "agent"
if entity_id.startswith("urn:trustgraph:question:"):
return "graphrag"
# Check what's derived from this entity
derived = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_DERIVED_FROM, entity_id
)
# Also check wasGeneratedBy (GraphRAG exploration uses this)
generated = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_GENERATED_BY, entity_id
)
all_children = derived + generated
for child_id in all_children:
child_types = query_triples(
socket, flow_id, user, collection,
s=child_id, p=RDF_TYPE, g=RETRIEVAL_GRAPH
)
for s, p, o in child_types:
if o == TG_ANALYSIS or o == TG_CONCLUSION:
return "agent"
if o == TG_EXPLORATION:
return "graphrag"
# Default to graphrag
return "graphrag"
def build_agent_trace(socket, flow_id, user, collection, session_id, api=None, max_answer=500):
"""Build the full explainability trace for an agent session."""
trace = {
"session_id": session_id,
"type": "agent",
"question": None,
"time": None,
"iterations": [],
"final_answer": None,
}
# Get session metadata
props = get_node_properties(socket, flow_id, user, collection, session_id)
trace["question"] = props.get(TG_QUERY, [None])[0]
trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0]
# Find all entities derived from this session (iterations and final)
# Start by looking for entities where prov:wasDerivedFrom = session_id
current_uri = session_id
iteration_num = 1
while True:
# Find entities derived from current
derived_ids = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_DERIVED_FROM, current_uri
)
if not derived_ids:
break
derived_id = derived_ids[0]
derived_props = get_node_properties(socket, flow_id, user, collection, derived_id)
# Check type
types = derived_props.get(RDF_TYPE, [])
if TG_ANALYSIS in types:
iteration = {
"id": derived_id,
"iteration_num": iteration_num,
"thought": derived_props.get(TG_THOUGHT, [None])[0],
"action": derived_props.get(TG_ACTION, [None])[0],
"arguments": derived_props.get(TG_ARGUMENTS, [None])[0],
"observation": derived_props.get(TG_OBSERVATION, [None])[0],
}
trace["iterations"].append(iteration)
current_uri = derived_id
iteration_num += 1
elif TG_CONCLUSION in types:
answer = derived_props.get(TG_ANSWER, [None])[0]
if answer and len(answer) > max_answer:
answer = answer[:max_answer] + "... [truncated]"
trace["final_answer"] = {
"id": derived_id,
"answer": answer,
}
break
else:
# Unknown type, stop traversal
break
return trace
def print_agent_text(trace):
"""Print agent trace in text format."""
print(f"=== Agent Session: {trace['session_id']} ===")
print() print()
if trace["question"]: if question:
print(f"Question: {trace['question']}") print(f"Question: {question.query}")
if trace["time"]: if question.timestamp:
print(f"Time: {trace['time']}") print(f"Time: {question.timestamp}")
print()
# Analysis steps
print("--- Analysis ---")
iterations = trace.get("iterations", [])
if iterations:
for iteration in iterations:
print(f"Analysis {iteration['iteration_num']}:")
print(f" Thought: {iteration.get('thought', 'N/A')}")
print(f" Action: {iteration.get('action', 'N/A')}")
args = iteration.get('arguments')
if args:
# Try to pretty-print JSON arguments
try:
import json
args_obj = json.loads(args)
args_str = json.dumps(args_obj, indent=4)
# Indent each line
args_lines = args_str.split('\n')
print(f" Arguments:")
for line in args_lines:
print(f" {line}")
except:
print(f" Arguments: {args}")
else:
print(f" Arguments: N/A")
obs = iteration.get('observation', 'N/A')
if obs and len(obs) > 200:
obs = obs[:200] + "... [truncated]"
print(f" Observation: {obs}")
print()
else:
print("No analysis steps recorded")
print()
# Conclusion
print("--- Conclusion ---")
final = trace.get("final_answer")
if final and final.get("answer"):
print("Answer:")
for line in final["answer"].split("\n"):
print(f" {line}")
else:
print("No conclusion recorded")
def print_agent_json(trace):
"""Print agent trace as JSON."""
print(json.dumps(trace, indent=2))
def build_trace(socket, flow_id, user, collection, question_id, api=None, show_provenance=False, max_answer=500):
"""Build the full explainability trace for a question."""
label_cache = {}
trace = {
"question_id": question_id,
"question": None,
"time": None,
"exploration": None,
"focus": None,
"synthesis": None,
}
# Get question metadata
props = get_node_properties(socket, flow_id, user, collection, question_id)
trace["question"] = props.get(TG_QUERY, [None])[0]
trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0]
# Find exploration: ?exploration prov:wasGeneratedBy question_id
exploration_ids = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_GENERATED_BY, question_id
)
if exploration_ids:
exploration_id = exploration_ids[0]
exploration_props = get_node_properties(socket, flow_id, user, collection, exploration_id)
trace["exploration"] = {
"id": exploration_id,
"edge_count": exploration_props.get(TG_EDGE_COUNT, [None])[0],
}
# Find focus: ?focus prov:wasDerivedFrom exploration_id
focus_ids = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_DERIVED_FROM, exploration_id
)
if focus_ids:
focus_id = focus_ids[0]
focus_props = get_node_properties(socket, flow_id, user, collection, focus_id)
# Get selected edges
edge_selection_uris = focus_props.get(TG_SELECTED_EDGE, [])
selected_edges = []
for edge_sel_uri in edge_selection_uris:
edge_sel_props = get_node_properties(socket, flow_id, user, collection, edge_sel_uri)
edge = edge_sel_props.get(TG_EDGE, [None])[0]
reasoning = edge_sel_props.get(TG_REASONING, [None])[0]
edge_info = {
"edge": edge,
"reasoning": reasoning,
}
# Trace provenance if requested
if show_provenance and isinstance(edge, dict):
provenance = trace_edge_provenance(
socket, flow_id, user, collection,
edge.get("s", ""), edge.get("p", ""), edge.get("o", ""),
label_cache
)
edge_info["provenance"] = provenance
selected_edges.append(edge_info)
trace["focus"] = {
"id": focus_id,
"selected_edges": selected_edges,
}
# Find synthesis: ?synthesis prov:wasDerivedFrom focus_id
synthesis_ids = find_by_predicate_object(
socket, flow_id, user, collection,
PROV_WAS_DERIVED_FROM, focus_id
)
if synthesis_ids:
synthesis_id = synthesis_ids[0]
synthesis_props = get_node_properties(socket, flow_id, user, collection, synthesis_id)
# Get content directly or via document reference
content = synthesis_props.get(TG_CONTENT, [None])[0]
doc_id = synthesis_props.get(TG_DOCUMENT, [None])[0]
if not content and doc_id and api:
content = get_document_content(api, user, doc_id, max_answer)
elif content and len(content) > max_answer:
content = content[:max_answer] + "... [truncated]"
trace["synthesis"] = {
"id": synthesis_id,
"document_id": doc_id,
"answer": content,
}
# Store label cache for formatting
trace["_label_cache"] = label_cache
return trace
def print_text(trace, show_provenance=False):
"""Print trace in text format."""
label_cache = trace.get("_label_cache", {})
print(f"=== GraphRAG Session: {trace['question_id']} ===")
print()
if trace["question"]:
print(f"Question: {trace['question']}")
if trace["time"]:
print(f"Time: {trace['time']}")
print() print()
# Exploration # Exploration
print("--- Exploration ---") print("--- Exploration ---")
exploration = trace.get("exploration") exploration = trace.get("exploration")
if exploration: if exploration:
edge_count = exploration.get("edge_count", "?") print(f"Retrieved {exploration.edge_count} edges from knowledge graph")
print(f"Retrieved {edge_count} edges from knowledge graph")
else: else:
print("No exploration data found") print("No exploration data found")
print() print()
@ -609,24 +188,28 @@ def print_text(trace, show_provenance=False):
print("--- Focus (Edge Selection) ---") print("--- Focus (Edge Selection) ---")
focus = trace.get("focus") focus = trace.get("focus")
if focus: if focus:
edges = focus.get("selected_edges", []) edges = focus.edge_selections
print(f"Selected {len(edges)} edges:") print(f"Selected {len(edges)} edges:")
print() print()
for i, edge_info in enumerate(edges, 1): label_cache = {}
edge = edge_info.get("edge")
reasoning = edge_info.get("reasoning")
if edge: for i, edge_sel in enumerate(edges, 1):
edge_str = format_edge(edge) if edge_sel.edge:
print(f" {i}. {edge_str}") s_label, p_label, o_label = explain_client.resolve_edge_labels(
edge_sel.edge, user, collection
)
print(f" {i}. ({s_label}, {p_label}, {o_label})")
if reasoning: if edge_sel.reasoning:
r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reasoning: {r_short}") print(f" Reasoning: {r_short}")
if show_provenance: if show_provenance and edge_sel.edge:
provenance = edge_info.get("provenance", []) provenance = trace_edge_provenance(
flow, user, collection, edge_sel.edge,
label_cache, explain_client
)
for chain in provenance: for chain in provenance:
chain_str = format_provenance_chain(chain) chain_str = format_provenance_chain(chain)
if chain_str: if chain_str:
@ -641,11 +224,9 @@ def print_text(trace, show_provenance=False):
print("--- Synthesis ---") print("--- Synthesis ---")
synthesis = trace.get("synthesis") synthesis = trace.get("synthesis")
if synthesis: if synthesis:
answer = synthesis.get("answer") if synthesis.content:
if answer:
print("Answer:") print("Answer:")
# Indent the answer for line in synthesis.content.split("\n"):
for line in answer.split("\n"):
print(f" {line}") print(f" {line}")
else: else:
print("No answer content found") print("No answer content found")
@ -653,11 +234,173 @@ def print_text(trace, show_provenance=False):
print("No synthesis data found") print("No synthesis data found")
def print_json(trace): def print_docrag_text(trace):
"""Print trace as JSON.""" """Print DocRAG trace in text format."""
# Remove internal cache before printing question = trace.get("question")
output = {k: v for k, v in trace.items() if not k.startswith("_")}
print(json.dumps(output, indent=2)) print(f"=== DocRAG Session: {question.uri if question else 'Unknown'} ===")
print()
if question:
print(f"Question: {question.query}")
if question.timestamp:
print(f"Time: {question.timestamp}")
print()
# Exploration
print("--- Exploration ---")
exploration = trace.get("exploration")
if exploration:
print(f"Retrieved {exploration.chunk_count} chunks from document store")
else:
print("No exploration data found")
print()
# Synthesis (no Focus step for DocRAG)
print("--- Synthesis ---")
synthesis = trace.get("synthesis")
if synthesis:
if synthesis.content:
print("Answer:")
for line in synthesis.content.split("\n"):
print(f" {line}")
else:
print("No answer content found")
else:
print("No synthesis data found")
def print_agent_text(trace):
"""Print Agent trace in text format."""
question = trace.get("question")
print(f"=== Agent Session: {question.uri if question else 'Unknown'} ===")
print()
if question:
print(f"Question: {question.query}")
if question.timestamp:
print(f"Time: {question.timestamp}")
print()
# Analysis steps
print("--- Analysis ---")
iterations = trace.get("iterations", [])
if iterations:
for i, analysis in enumerate(iterations, 1):
print(f"Analysis {i}:")
print(f" Thought: {analysis.thought or 'N/A'}")
print(f" Action: {analysis.action or 'N/A'}")
if analysis.arguments:
# Try to pretty-print JSON arguments
try:
args_obj = json.loads(analysis.arguments)
args_str = json.dumps(args_obj, indent=4)
print(f" Arguments:")
for line in args_str.split('\n'):
print(f" {line}")
except Exception:
print(f" Arguments: {analysis.arguments}")
else:
print(f" Arguments: N/A")
obs = analysis.observation or 'N/A'
if obs and len(obs) > 200:
obs = obs[:200] + "... [truncated]"
print(f" Observation: {obs}")
print()
else:
print("No analysis steps recorded")
print()
# Conclusion
print("--- Conclusion ---")
conclusion = trace.get("conclusion")
if conclusion and conclusion.answer:
print("Answer:")
for line in conclusion.answer.split("\n"):
print(f" {line}")
else:
print("No conclusion recorded")
def trace_to_dict(trace, trace_type):
"""Convert trace entities to JSON-serializable dict."""
if trace_type == "agent":
question = trace.get("question")
return {
"type": "agent",
"session_id": question.uri if question else None,
"question": question.query if question else None,
"time": question.timestamp if question else None,
"iterations": [
{
"id": a.uri,
"thought": a.thought,
"action": a.action,
"arguments": a.arguments,
"observation": a.observation,
}
for a in trace.get("iterations", [])
],
"conclusion": {
"id": trace["conclusion"].uri,
"answer": trace["conclusion"].answer,
} if trace.get("conclusion") else None,
}
elif trace_type == "docrag":
question = trace.get("question")
exploration = trace.get("exploration")
synthesis = trace.get("synthesis")
return {
"type": "docrag",
"question_id": question.uri if question else None,
"question": question.query if question else None,
"time": question.timestamp if question else None,
"exploration": {
"id": exploration.uri,
"chunk_count": exploration.chunk_count,
} if exploration else None,
"synthesis": {
"id": synthesis.uri,
"document_uri": synthesis.document_uri,
"answer": synthesis.content,
} if synthesis else None,
}
else:
# graphrag
question = trace.get("question")
exploration = trace.get("exploration")
focus = trace.get("focus")
synthesis = trace.get("synthesis")
return {
"type": "graphrag",
"question_id": question.uri if question else None,
"question": question.query if question else None,
"time": question.timestamp if question else None,
"exploration": {
"id": exploration.uri,
"edge_count": exploration.edge_count,
} if exploration else None,
"focus": {
"id": focus.uri,
"selected_edges": [
{
"edge": edge_sel.edge,
"reasoning": edge_sel.reasoning,
}
for edge_sel in focus.edge_selections
],
} if focus else None,
"synthesis": {
"id": synthesis.uri,
"document_uri": synthesis.document_uri,
"answer": synthesis.content,
} if synthesis else None,
}
def main(): def main():
@ -727,50 +470,69 @@ def main():
try: try:
api = Api(args.api_url, token=args.token) api = Api(args.api_url, token=args.token)
socket = api.socket() socket = api.socket()
flow = socket.flow(args.flow_id)
explain_client = ExplainabilityClient(flow)
try: try:
# Detect trace type (agent vs graphrag) # Detect trace type
trace_type = detect_trace_type( trace_type = explain_client.detect_session_type(
socket=socket, args.question_id,
flow_id=args.flow_id, graph=RETRIEVAL_GRAPH,
user=args.user, user=args.user,
collection=args.collection, collection=args.collection,
entity_id=args.question_id,
) )
if trace_type == "agent": if trace_type == "agent":
# Build and print agent trace # Fetch and display agent trace
trace = build_agent_trace( trace = explain_client.fetch_agent_trace(
socket=socket, args.question_id,
flow_id=args.flow_id, graph=RETRIEVAL_GRAPH,
user=args.user, user=args.user,
collection=args.collection, collection=args.collection,
session_id=args.question_id,
api=api, api=api,
max_answer=args.max_answer, max_content=args.max_answer,
) )
if args.format == 'json': if args.format == 'json':
print_agent_json(trace) print(json.dumps(trace_to_dict(trace, "agent"), indent=2))
else: else:
print_agent_text(trace) print_agent_text(trace)
else:
# Build and print GraphRAG trace (existing behavior) elif trace_type == "docrag":
trace = build_trace( # Fetch and display DocRAG trace
socket=socket, trace = explain_client.fetch_docrag_trace(
flow_id=args.flow_id, args.question_id,
graph=RETRIEVAL_GRAPH,
user=args.user, user=args.user,
collection=args.collection, collection=args.collection,
question_id=args.question_id,
api=api, api=api,
show_provenance=args.show_provenance, max_content=args.max_answer,
max_answer=args.max_answer,
) )
if args.format == 'json': if args.format == 'json':
print_json(trace) print(json.dumps(trace_to_dict(trace, "docrag"), indent=2))
else: else:
print_text(trace, show_provenance=args.show_provenance) print_docrag_text(trace)
else:
# Fetch and display GraphRAG trace
trace = explain_client.fetch_graphrag_trace(
args.question_id,
graph=RETRIEVAL_GRAPH,
user=args.user,
collection=args.collection,
api=api,
max_content=args.max_answer,
)
if args.format == 'json':
print(json.dumps(trace_to_dict(trace, "graphrag"), indent=2))
else:
print_graphrag_text(
trace, explain_client, flow,
args.user, args.collection,
show_provenance=args.show_provenance
)
finally: finally:
socket.close() socket.close()

View file

@ -2,6 +2,8 @@
Simple agent infrastructure broadly implements the ReAct flow. Simple agent infrastructure broadly implements the ReAct flow.
""" """
import asyncio
import base64
import json import json
import re import re
import sys import sys
@ -17,9 +19,13 @@ from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
from ... base import ProducerSpec from ... base import ProducerSpec
from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ... schema import Triples, Metadata from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
# Provenance imports for agent explainability # Provenance imports for agent explainability
from trustgraph.provenance import ( from trustgraph.provenance import (
@ -41,6 +47,8 @@ from . types import Final, Action, Tool, Argument
default_ident = "agent-manager" default_ident = "agent-manager"
default_max_iterations = 10 default_max_iterations = 10
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(AgentService): class Processor(AgentService):
@ -129,6 +137,115 @@ class Processor(AgentService):
) )
) )
# Librarian client for storing answer content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_librarian_requests = {}
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
else:
logger.warning(f"Received unexpected librarian response: {request_id}")
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
user: User ID
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or "Agent Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_librarian_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_tools_config(self, config, version): async def on_tools_config(self, config, version):
logger.info(f"Loading configuration version {version}") logger.info(f"Loading configuration version {version}")
@ -347,6 +464,15 @@ class Processor(AgentService):
)) ))
logger.debug(f"Emitted session triples for {session_uri}") logger.debug(f"Emitted session triples for {session_uri}")
# Send explain event for session
if streaming:
await respond(AgentResponse(
chunk_type="explain",
content="",
explain_id=session_uri,
explain_graph=GRAPH_RETRIEVAL,
))
logger.info(f"Question: {request.question}") logger.info(f"Question: {request.question}")
if len(history) >= self.max_iterations: if len(history) >= self.max_iterations:
@ -504,8 +630,28 @@ class Processor(AgentService):
else: else:
parent_uri = session_uri parent_uri = session_uri
# Save answer to librarian
answer_doc_id = None
if f:
answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
try:
await self.save_answer_content(
doc_id=answer_doc_id,
user=request.user,
content=f,
title=f"Agent Answer: {request.question[:50]}...",
)
logger.debug(f"Saved answer to librarian: {answer_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
answer_doc_id = None # Fall back to inline content
final_triples = set_graph( final_triples = set_graph(
agent_final_triples(final_uri, parent_uri, f), agent_final_triples(
final_uri, parent_uri,
answer="" if answer_doc_id else f,
document_id=answer_doc_id,
),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
) )
await flow("explainability").send(Triples( await flow("explainability").send(Triples(
@ -518,6 +664,15 @@ class Processor(AgentService):
)) ))
logger.debug(f"Emitted final triples for {final_uri}") logger.debug(f"Emitted final triples for {final_uri}")
# Send explain event for conclusion
if streaming:
await respond(AgentResponse(
chunk_type="explain",
content="",
explain_id=final_uri,
explain_graph=GRAPH_RETRIEVAL,
))
if streaming: if streaming:
# Streaming format - send end-of-dialog marker # Streaming format - send end-of-dialog marker
# Answer chunks were already sent via answer() callback during parsing # Answer chunks were already sent via answer() callback during parsing
@ -558,14 +713,48 @@ class Processor(AgentService):
else: else:
parent_uri = session_uri parent_uri = session_uri
# Save thought to librarian
thought_doc_id = None
if act.thought:
thought_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
try:
await self.save_answer_content(
doc_id=thought_doc_id,
user=request.user,
content=act.thought,
title=f"Agent Thought: {act.name}",
)
logger.debug(f"Saved thought to librarian: {thought_doc_id}")
except Exception as e:
logger.warning(f"Failed to save thought to librarian: {e}")
thought_doc_id = None
# Save observation to librarian
observation_doc_id = None
if act.observation:
observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
try:
await self.save_answer_content(
doc_id=observation_doc_id,
user=request.user,
content=act.observation,
title=f"Agent Observation: {act.name}",
)
logger.debug(f"Saved observation to librarian: {observation_doc_id}")
except Exception as e:
logger.warning(f"Failed to save observation to librarian: {e}")
observation_doc_id = None
iter_triples = set_graph( iter_triples = set_graph(
agent_iteration_triples( agent_iteration_triples(
iteration_uri, iteration_uri,
parent_uri, parent_uri,
act.thought, thought="" if thought_doc_id else act.thought,
act.name, action=act.name,
act.arguments, arguments=act.arguments,
act.observation, observation="" if observation_doc_id else act.observation,
thought_document_id=thought_doc_id,
observation_document_id=observation_doc_id,
), ),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
) )
@ -579,6 +768,15 @@ class Processor(AgentService):
)) ))
logger.debug(f"Emitted iteration triples for {iteration_uri}") logger.debug(f"Emitted iteration triples for {iteration_uri}")
# Send explain event for iteration
if streaming:
await respond(AgentResponse(
chunk_type="explain",
content="",
explain_id=iteration_uri,
explain_graph=GRAPH_RETRIEVAL,
))
history.append(act) history.append(act)
# Handle state transitions if tool execution was successful # Handle state transitions if tool execution was successful

View file

@ -109,7 +109,7 @@ class DocumentRag:
async def query( async def query(
self, query, user="trustgraph", collection="default", self, query, user="trustgraph", collection="default",
doc_limit=20, streaming=False, chunk_callback=None, doc_limit=20, streaming=False, chunk_callback=None,
explain_callback=None, explain_callback=None, save_answer_callback=None,
): ):
""" """
Execute a Document RAG query with optional explainability tracking. Execute a Document RAG query with optional explainability tracking.
@ -122,6 +122,7 @@ class DocumentRag:
streaming: Enable streaming LLM response streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for explainability explain_callback: async def callback(triples, explain_id) for explainability
save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
Returns: Returns:
str: The synthesized answer text str: The synthesized answer text
@ -192,9 +193,28 @@ class DocumentRag:
# Emit synthesis explainability after answer generated # Emit synthesis explainability after answer generated
if explain_callback: if explain_callback:
synthesis_doc_id = None
answer_text = resp if resp else "" answer_text = resp if resp else ""
# Save answer to librarian if callback provided
if save_answer_callback and answer_text:
# Generate document ID as URN matching query-time provenance format
synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
try:
await save_answer_callback(synthesis_doc_id, answer_text)
if self.verbose:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None # Fall back to inline content
# Generate triples with document reference or inline content
syn_triples = set_graph( syn_triples = set_graph(
docrag_synthesis_triples(syn_uri, exp_uri, answer_text), docrag_synthesis_triples(
syn_uri, exp_uri,
answer_text="" if synthesis_doc_id else answer_text,
document_id=synthesis_doc_id,
),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
) )
await explain_callback(syn_triples, syn_uri) await explain_callback(syn_triples, syn_uri)

View file

@ -8,8 +8,10 @@ import asyncio
import base64 import base64
import logging import logging
import uuid
from ... schema import DocumentRagQuery, DocumentRagResponse, Error from ... schema import DocumentRagQuery, DocumentRagResponse, Error
from ... schema import LibrarianRequest, LibrarianResponse from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples, Metadata from ... schema import Triples, Metadata
from ... provenance import GRAPH_RETRIEVAL from ... provenance import GRAPH_RETRIEVAL
@ -179,6 +181,62 @@ class Processor(FlowProcessor):
self.pending_requests.pop(request_id, None) self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching chunk {chunk_id}") raise RuntimeError(f"Timeout fetching chunk {chunk_id}")
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
user: User ID
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or "DocumentRAG Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_request(self, msg, consumer, flow): async def on_request(self, msg, consumer, flow):
try: try:
@ -222,10 +280,20 @@ class Processor(FlowProcessor):
response=None, response=None,
explain_id=explain_id, explain_id=explain_id,
explain_graph=GRAPH_RETRIEVAL, explain_graph=GRAPH_RETRIEVAL,
message_type="explain",
), ),
properties={"id": id} properties={"id": id}
) )
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
doc_id=doc_id,
user=v.user,
content=answer_text,
title=f"DocumentRAG Answer: {v.query[:50]}...",
)
# Check if streaming is requested # Check if streaming is requested
if v.streaming: if v.streaming:
# Define async callback for streaming chunks # Define async callback for streaming chunks
@ -235,6 +303,7 @@ class Processor(FlowProcessor):
DocumentRagResponse( DocumentRagResponse(
response=chunk, response=chunk,
end_of_stream=end_of_stream, end_of_stream=end_of_stream,
message_type="chunk",
error=None error=None
), ),
properties={"id": id} properties={"id": id}
@ -250,6 +319,17 @@ class Processor(FlowProcessor):
streaming=True, streaming=True,
chunk_callback=send_chunk, chunk_callback=send_chunk,
explain_callback=send_explainability, explain_callback=send_explainability,
save_answer_callback=save_answer,
)
# Send end_of_session to signal entire session is complete
await flow("response").send(
DocumentRagResponse(
response=None,
end_of_session=True,
message_type="end",
),
properties={"id": id}
) )
else: else:
# Non-streaming path (existing behavior) # Non-streaming path (existing behavior)
@ -259,6 +339,7 @@ class Processor(FlowProcessor):
collection=v.collection, collection=v.collection,
doc_limit=doc_limit, doc_limit=doc_limit,
explain_callback=send_explainability, explain_callback=send_explainability,
save_answer_callback=save_answer,
) )
await flow("response").send( await flow("response").send(