mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-08 22:35:14 +02:00
Add unified explainability support and librarian storage for (#693)
Add unified explainability support and librarian storage for all retrieval engines Implements consistent explainability/provenance tracking across GraphRAG, DocumentRAG, and Agent retrieval engines. All large content (answers, thoughts, observations) is now stored in librarian rather than as inline literals in the knowledge graph. Explainability API: - New explainability.py module with entity classes (Question, Exploration, Focus, Synthesis, Analysis, Conclusion) and ExplainabilityClient - Quiescence-based eventual consistency handling for trace fetching - Content fetching from librarian with retry logic CLI updates: - tg-invoke-graph-rag -x/--explainable flag returns explain_id - tg-invoke-document-rag -x/--explainable flag returns explain_id - tg-invoke-agent -x/--explainable flag returns explain_id - tg-list-explain-traces uses new explainability API - tg-show-explain-trace handles all three trace types Agent provenance: - Records session, iterations (think/act/observe), and conclusion - Stores thoughts and observations in librarian with document references - New predicates: tg:thoughtDocument, tg:observationDocument DocumentRAG provenance: - Records question, exploration (chunk retrieval), and synthesis - Stores answers in librarian with document references Schema changes: - AgentResponse: added explain_id, explain_graph fields - RetrievalResponse: added explain_id, explain_graph fields - agent_iteration_triples: supports thought_document_id, observation_document_id Update tests.
This commit is contained in:
parent
aecf00f040
commit
35128ff019
24 changed files with 2736 additions and 846 deletions
|
|
@ -110,16 +110,17 @@ class TestRAGTranslatorCompletionFlags:
|
|||
assert response_dict["end_of_stream"] is True
|
||||
assert response_dict["end_of_session"] is False
|
||||
|
||||
def test_document_rag_translator_is_final_with_end_of_stream_true(self):
|
||||
def test_document_rag_translator_is_final_with_end_of_session_true(self):
|
||||
"""
|
||||
Test that DocumentRagResponseTranslator returns is_final=True
|
||||
when end_of_stream=True.
|
||||
when end_of_session=True.
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("document-rag")
|
||||
response = DocumentRagResponse(
|
||||
response="A document about cats.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
|
|
@ -127,9 +128,31 @@ class TestRAGTranslatorCompletionFlags:
|
|||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when end_of_stream=True"
|
||||
assert is_final is True, "is_final must be True when end_of_session=True"
|
||||
assert response_dict["response"] == "A document about cats."
|
||||
assert response_dict["end_of_session"] is True
|
||||
|
||||
def test_document_rag_translator_end_of_stream_not_final(self):
|
||||
"""
|
||||
Test that end_of_stream=True alone does NOT make is_final=True.
|
||||
The session continues with provenance messages after LLM stream completes.
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("document-rag")
|
||||
response = DocumentRagResponse(
|
||||
response="Final chunk",
|
||||
end_of_stream=True,
|
||||
end_of_session=False, # Session continues with provenance
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
|
||||
assert response_dict["end_of_stream"] is True
|
||||
assert response_dict["end_of_session"] is False
|
||||
|
||||
def test_document_rag_translator_is_final_with_end_of_stream_false(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -30,10 +30,13 @@ class TestAgentStructuredQueryIntegration:
|
|||
pulsar_client=AsyncMock(),
|
||||
max_iterations=3
|
||||
)
|
||||
|
||||
|
||||
# Mock the client method for structured query
|
||||
proc.client = MagicMock()
|
||||
|
||||
|
||||
# Mock librarian to avoid hanging on save operations
|
||||
proc.save_answer_content = AsyncMock(return_value=None)
|
||||
|
||||
return proc
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -28,6 +28,9 @@ class TestAgentServiceNonStreaming:
|
|||
max_iterations=10
|
||||
)
|
||||
|
||||
# Mock librarian to avoid hanging on save operations
|
||||
processor.save_answer_content = AsyncMock(return_value=None)
|
||||
|
||||
# Track all responses sent
|
||||
sent_responses = []
|
||||
|
||||
|
|
@ -106,6 +109,9 @@ class TestAgentServiceNonStreaming:
|
|||
max_iterations=10
|
||||
)
|
||||
|
||||
# Mock librarian to avoid hanging on save operations
|
||||
processor.save_answer_content = AsyncMock(return_value=None)
|
||||
|
||||
# Track all responses sent
|
||||
sent_responses = []
|
||||
|
||||
|
|
@ -173,6 +179,9 @@ class TestAgentServiceNonStreaming:
|
|||
max_iterations=10
|
||||
)
|
||||
|
||||
# Mock librarian to avoid hanging on save operations
|
||||
processor.save_answer_content = AsyncMock(return_value=None)
|
||||
|
||||
# Track all responses sent
|
||||
sent_responses = []
|
||||
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ class TestDocumentRagService:
|
|||
collection="test_coll_1", # Must be from message, not hardcoded default
|
||||
doc_limit=5,
|
||||
explain_callback=ANY, # Explainability callback is always passed
|
||||
save_answer_callback=ANY, # Librarian save callback is always passed
|
||||
)
|
||||
|
||||
# Verify response was sent
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ from .flow import Flow, FlowInstance
|
|||
from .async_flow import AsyncFlow, AsyncFlowInstance
|
||||
|
||||
# WebSocket clients
|
||||
from .socket_client import SocketClient, SocketFlowInstance
|
||||
from .socket_client import SocketClient, SocketFlowInstance, build_term
|
||||
from .async_socket_client import AsyncSocketClient, AsyncSocketFlowInstance
|
||||
|
||||
# Bulk operation clients
|
||||
|
|
@ -70,6 +70,21 @@ from .async_bulk_client import AsyncBulkClient
|
|||
from .metrics import Metrics
|
||||
from .async_metrics import AsyncMetrics
|
||||
|
||||
# Explainability
|
||||
from .explainability import (
|
||||
ExplainabilityClient,
|
||||
ExplainEntity,
|
||||
Question,
|
||||
Exploration,
|
||||
Focus,
|
||||
Synthesis,
|
||||
Analysis,
|
||||
Conclusion,
|
||||
EdgeSelection,
|
||||
wire_triples_to_tuples,
|
||||
extract_term_value,
|
||||
)
|
||||
|
||||
# Types
|
||||
from .types import (
|
||||
Triple,
|
||||
|
|
@ -85,6 +100,7 @@ from .types import (
|
|||
AgentObservation,
|
||||
AgentAnswer,
|
||||
RAGChunk,
|
||||
ProvenanceEvent,
|
||||
)
|
||||
|
||||
# Exceptions
|
||||
|
|
@ -124,6 +140,7 @@ __all__ = [
|
|||
"SocketFlowInstance",
|
||||
"AsyncSocketClient",
|
||||
"AsyncSocketFlowInstance",
|
||||
"build_term",
|
||||
|
||||
# Bulk operation clients
|
||||
"BulkClient",
|
||||
|
|
@ -133,6 +150,19 @@ __all__ = [
|
|||
"Metrics",
|
||||
"AsyncMetrics",
|
||||
|
||||
# Explainability
|
||||
"ExplainabilityClient",
|
||||
"ExplainEntity",
|
||||
"Question",
|
||||
"Exploration",
|
||||
"Focus",
|
||||
"Synthesis",
|
||||
"Analysis",
|
||||
"Conclusion",
|
||||
"EdgeSelection",
|
||||
"wire_triples_to_tuples",
|
||||
"extract_term_value",
|
||||
|
||||
# Types
|
||||
"Triple",
|
||||
"Uri",
|
||||
|
|
@ -147,6 +177,7 @@ __all__ = [
|
|||
"AgentObservation",
|
||||
"AgentAnswer",
|
||||
"RAGChunk",
|
||||
"ProvenanceEvent",
|
||||
|
||||
# Exceptions
|
||||
"ProtocolException",
|
||||
|
|
|
|||
1132
trustgraph-base/trustgraph/api/explainability.py
Normal file
1132
trustgraph-base/trustgraph/api/explainability.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -15,6 +15,63 @@ from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, Strea
|
|||
from . exceptions import ProtocolException, raise_from_error_dict
|
||||
|
||||
|
||||
def build_term(value: Any, term_type: Optional[str] = None,
|
||||
datatype: Optional[str] = None, language: Optional[str] = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Build wire-format Term dict from a value.
|
||||
|
||||
Auto-detection rules (when term_type is None):
|
||||
- Already a dict with 't' key -> return as-is (already a Term)
|
||||
- Starts with http://, https://, urn: -> IRI
|
||||
- Wrapped in <> (e.g., <http://...>) -> IRI (angle brackets stripped)
|
||||
- Anything else -> literal
|
||||
|
||||
Args:
|
||||
value: The term value (string, dict, or None)
|
||||
term_type: One of 'iri', 'literal', or None for auto-detect
|
||||
datatype: Datatype for literal objects (e.g., xsd:integer)
|
||||
language: Language tag for literal objects (e.g., en)
|
||||
|
||||
Returns:
|
||||
dict: Wire-format Term dict, or None if value is None
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# If already a Term dict, return as-is
|
||||
if isinstance(value, dict) and "t" in value:
|
||||
return value
|
||||
|
||||
# Convert to string for processing
|
||||
value = str(value)
|
||||
|
||||
# Auto-detect type if not specified
|
||||
if term_type is None:
|
||||
if value.startswith("<") and value.endswith(">") and not value.startswith("<<"):
|
||||
# Angle-bracket wrapped IRI: <http://...>
|
||||
value = value[1:-1] # Strip < and >
|
||||
term_type = "iri"
|
||||
elif value.startswith(("http://", "https://", "urn:")):
|
||||
term_type = "iri"
|
||||
else:
|
||||
term_type = "literal"
|
||||
|
||||
if term_type == "iri":
|
||||
# Strip angle brackets if present
|
||||
if value.startswith("<") and value.endswith(">"):
|
||||
value = value[1:-1]
|
||||
return {"t": "i", "i": value}
|
||||
elif term_type == "literal":
|
||||
result = {"t": "l", "v": value}
|
||||
if datatype:
|
||||
result["dt"] = datatype
|
||||
if language:
|
||||
result["ln"] = language
|
||||
return result
|
||||
else:
|
||||
raise ValueError(f"Unknown term type: {term_type}")
|
||||
|
||||
|
||||
class SocketClient:
|
||||
"""
|
||||
Synchronous WebSocket client for streaming operations.
|
||||
|
|
@ -92,7 +149,8 @@ class SocketClient:
|
|||
flow: Optional[str],
|
||||
request: Dict[str, Any],
|
||||
streaming: bool = False,
|
||||
streaming_raw: bool = False
|
||||
streaming_raw: bool = False,
|
||||
include_provenance: bool = False
|
||||
) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]:
|
||||
"""Synchronous wrapper around async WebSocket communication.
|
||||
|
||||
|
|
@ -119,7 +177,7 @@ class SocketClient:
|
|||
return self._streaming_generator_raw(service, flow, request, loop)
|
||||
elif streaming:
|
||||
# Parsed streaming for agent/RAG chunk types
|
||||
return self._streaming_generator(service, flow, request, loop)
|
||||
return self._streaming_generator(service, flow, request, loop, include_provenance)
|
||||
else:
|
||||
# Non-streaming single response
|
||||
return loop.run_until_complete(self._send_request_async(service, flow, request))
|
||||
|
|
@ -129,10 +187,11 @@ class SocketClient:
|
|||
service: str,
|
||||
flow: Optional[str],
|
||||
request: Dict[str, Any],
|
||||
loop: asyncio.AbstractEventLoop
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
include_provenance: bool = False
|
||||
) -> Iterator[StreamingChunk]:
|
||||
"""Generator that yields streaming chunks (for agent/RAG responses)"""
|
||||
async_gen = self._send_request_async_streaming(service, flow, request)
|
||||
async_gen = self._send_request_async_streaming(service, flow, request, include_provenance)
|
||||
|
||||
try:
|
||||
while True:
|
||||
|
|
@ -265,7 +324,8 @@ class SocketClient:
|
|||
self,
|
||||
service: str,
|
||||
flow: Optional[str],
|
||||
request: Dict[str, Any]
|
||||
request: Dict[str, Any],
|
||||
include_provenance: bool = False
|
||||
) -> Iterator[StreamingChunk]:
|
||||
"""Async implementation of WebSocket request (streaming)"""
|
||||
# Generate unique request ID
|
||||
|
|
@ -309,8 +369,8 @@ class SocketClient:
|
|||
raise_from_error_dict(resp["error"])
|
||||
|
||||
# Parse different chunk types
|
||||
chunk = self._parse_chunk(resp)
|
||||
if chunk is not None: # Skip provenance messages in streaming
|
||||
chunk = self._parse_chunk(resp, include_provenance=include_provenance)
|
||||
if chunk is not None: # Skip provenance messages unless include_provenance
|
||||
yield chunk
|
||||
|
||||
# Check if this is the final message
|
||||
|
|
@ -325,14 +385,26 @@ class SocketClient:
|
|||
chunk_type = resp.get("chunk_type")
|
||||
message_type = resp.get("message_type")
|
||||
|
||||
# Handle new GraphRAG message format with message_type
|
||||
if message_type == "provenance":
|
||||
# Handle GraphRAG/DocRAG message format with message_type
|
||||
if message_type == "explain":
|
||||
if include_provenance:
|
||||
# Return provenance event for explainability
|
||||
return ProvenanceEvent(provenance_id=resp.get("provenance_id", ""))
|
||||
return ProvenanceEvent(
|
||||
explain_id=resp.get("explain_id", ""),
|
||||
explain_graph=resp.get("explain_graph", "")
|
||||
)
|
||||
# Provenance messages are not yielded to user - they're metadata
|
||||
return None
|
||||
|
||||
# Handle Agent message format with chunk_type="explain"
|
||||
if chunk_type == "explain":
|
||||
if include_provenance:
|
||||
return ProvenanceEvent(
|
||||
explain_id=resp.get("explain_id", ""),
|
||||
explain_graph=resp.get("explain_graph", "")
|
||||
)
|
||||
return None
|
||||
|
||||
if chunk_type == "thought":
|
||||
return AgentThought(
|
||||
content=resp.get("content", ""),
|
||||
|
|
@ -477,6 +549,95 @@ class SocketFlowInstance:
|
|||
# regardless of streaming flag, so always use the streaming code path
|
||||
return self.client._send_request_sync("agent", self.flow_id, request, streaming=True)
|
||||
|
||||
def agent_explain(
|
||||
self,
|
||||
question: str,
|
||||
user: str,
|
||||
collection: str,
|
||||
state: Optional[Dict[str, Any]] = None,
|
||||
group: Optional[str] = None,
|
||||
history: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Union[StreamingChunk, ProvenanceEvent]]:
|
||||
"""
|
||||
Execute an agent operation with explainability support.
|
||||
|
||||
Streams both content chunks (AgentThought, AgentObservation, AgentAnswer)
|
||||
and provenance events (ProvenanceEvent). Provenance events contain URIs
|
||||
that can be fetched using ExplainabilityClient to get detailed information
|
||||
about the agent's reasoning process.
|
||||
|
||||
Agent trace consists of:
|
||||
- Session: The initial question and session metadata
|
||||
- Iterations: Each thought/action/observation cycle
|
||||
- Conclusion: The final answer
|
||||
|
||||
Args:
|
||||
question: User question or instruction
|
||||
user: User identifier
|
||||
collection: Collection identifier for provenance storage
|
||||
state: Optional state dictionary for stateful conversations
|
||||
group: Optional group identifier for multi-user contexts
|
||||
history: Optional conversation history as list of message dicts
|
||||
**kwargs: Additional parameters passed to the agent service
|
||||
|
||||
Yields:
|
||||
Union[StreamingChunk, ProvenanceEvent]: Agent chunks and provenance events
|
||||
|
||||
Example:
|
||||
```python
|
||||
from trustgraph.api import Api, ExplainabilityClient, ProvenanceEvent
|
||||
from trustgraph.api import AgentThought, AgentObservation, AgentAnswer
|
||||
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
explain_client = ExplainabilityClient(flow)
|
||||
|
||||
provenance_ids = []
|
||||
for item in flow.agent_explain(
|
||||
question="What is the capital of France?",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
):
|
||||
if isinstance(item, AgentThought):
|
||||
print(f"[Thought] {item.content}")
|
||||
elif isinstance(item, AgentObservation):
|
||||
print(f"[Observation] {item.content}")
|
||||
elif isinstance(item, AgentAnswer):
|
||||
print(f"[Answer] {item.content}")
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
provenance_ids.append(item.explain_id)
|
||||
|
||||
# Fetch session trace after completion
|
||||
if provenance_ids:
|
||||
trace = explain_client.fetch_agent_trace(
|
||||
provenance_ids[0], # Session URI is first
|
||||
graph="urn:graph:retrieval",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
```
|
||||
"""
|
||||
request = {
|
||||
"question": question,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"streaming": True # Always streaming for explain
|
||||
}
|
||||
if state is not None:
|
||||
request["state"] = state
|
||||
if group is not None:
|
||||
request["group"] = group
|
||||
if history is not None:
|
||||
request["history"] = history
|
||||
request.update(kwargs)
|
||||
|
||||
# Use streaming with provenance enabled
|
||||
return self.client._send_request_sync(
|
||||
"agent", self.flow_id, request,
|
||||
streaming=True, include_provenance=True
|
||||
)
|
||||
|
||||
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
|
||||
"""
|
||||
Execute text completion with optional streaming.
|
||||
|
|
@ -596,6 +757,86 @@ class SocketFlowInstance:
|
|||
else:
|
||||
return result.get("response", "")
|
||||
|
||||
def graph_rag_explain(
|
||||
self,
|
||||
query: str,
|
||||
user: str,
|
||||
collection: str,
|
||||
max_subgraph_size: int = 1000,
|
||||
max_subgraph_count: int = 5,
|
||||
max_entity_distance: int = 3,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
||||
"""
|
||||
Execute graph-based RAG query with explainability support.
|
||||
|
||||
Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent).
|
||||
Provenance events contain URIs that can be fetched using ExplainabilityClient
|
||||
to get detailed information about how the response was generated.
|
||||
|
||||
Args:
|
||||
query: Natural language query
|
||||
user: User/keyspace identifier
|
||||
collection: Collection identifier
|
||||
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
|
||||
max_subgraph_count: Maximum number of subgraphs (default: 5)
|
||||
max_entity_distance: Maximum traversal depth (default: 3)
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
Yields:
|
||||
Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events
|
||||
|
||||
Example:
|
||||
```python
|
||||
from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent
|
||||
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
explain_client = ExplainabilityClient(flow)
|
||||
|
||||
provenance_ids = []
|
||||
response_text = ""
|
||||
|
||||
for item in flow.graph_rag_explain(
|
||||
query="Tell me about Marie Curie",
|
||||
user="trustgraph",
|
||||
collection="scientists"
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
response_text += item.content
|
||||
print(item.content, end='', flush=True)
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
provenance_ids.append(item.provenance_id)
|
||||
|
||||
# Fetch explainability details
|
||||
for prov_id in provenance_ids:
|
||||
entity = explain_client.fetch_entity(
|
||||
prov_id,
|
||||
graph="urn:graph:retrieval",
|
||||
user="trustgraph",
|
||||
collection="scientists"
|
||||
)
|
||||
print(f"Entity: {entity}")
|
||||
```
|
||||
"""
|
||||
request = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"max-subgraph-size": max_subgraph_size,
|
||||
"max-subgraph-count": max_subgraph_count,
|
||||
"max-entity-distance": max_entity_distance,
|
||||
"streaming": True,
|
||||
"explainable": True, # Enable explainability mode
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
# Use streaming with provenance events included
|
||||
return self.client._send_request_sync(
|
||||
"graph-rag", self.flow_id, request,
|
||||
streaming=True, include_provenance=True
|
||||
)
|
||||
|
||||
def document_rag(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -654,6 +895,79 @@ class SocketFlowInstance:
|
|||
else:
|
||||
return result.get("response", "")
|
||||
|
||||
def document_rag_explain(
|
||||
self,
|
||||
query: str,
|
||||
user: str,
|
||||
collection: str,
|
||||
doc_limit: int = 10,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
||||
"""
|
||||
Execute document-based RAG query with explainability support.
|
||||
|
||||
Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent).
|
||||
Provenance events contain URIs that can be fetched using ExplainabilityClient
|
||||
to get detailed information about how the response was generated.
|
||||
|
||||
Document RAG trace consists of:
|
||||
- Question: The user's query
|
||||
- Exploration: Chunks retrieved from document store (chunk_count)
|
||||
- Synthesis: The generated answer
|
||||
|
||||
Args:
|
||||
query: Natural language query
|
||||
user: User/keyspace identifier
|
||||
collection: Collection identifier
|
||||
doc_limit: Maximum document chunks to retrieve (default: 10)
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
Yields:
|
||||
Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events
|
||||
|
||||
Example:
|
||||
```python
|
||||
from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent
|
||||
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
explain_client = ExplainabilityClient(flow)
|
||||
|
||||
for item in flow.document_rag_explain(
|
||||
query="Summarize the key findings",
|
||||
user="trustgraph",
|
||||
collection="research-papers",
|
||||
doc_limit=5
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
print(item.content, end='', flush=True)
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
# Fetch entity details
|
||||
entity = explain_client.fetch_entity(
|
||||
item.explain_id,
|
||||
graph=item.explain_graph,
|
||||
user="trustgraph",
|
||||
collection="research-papers"
|
||||
)
|
||||
print(f"Event: {entity}", file=sys.stderr)
|
||||
```
|
||||
"""
|
||||
request = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"streaming": True,
|
||||
"explainable": True,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
# Use streaming with provenance events included
|
||||
return self.client._send_request_sync(
|
||||
"document-rag", self.flow_id, request,
|
||||
streaming=True, include_provenance=True
|
||||
)
|
||||
|
||||
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
|
||||
"""Generator for RAG streaming (graph-rag and document-rag)"""
|
||||
for chunk in result:
|
||||
|
|
@ -831,28 +1145,30 @@ class SocketFlowInstance:
|
|||
|
||||
def triples_query(
|
||||
self,
|
||||
s: Optional[str] = None,
|
||||
p: Optional[str] = None,
|
||||
o: Optional[str] = None,
|
||||
s: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
p: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
o: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
g: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
collection: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
**kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query knowledge graph triples using pattern matching.
|
||||
|
||||
Args:
|
||||
s: Subject URI (optional, use None for wildcard)
|
||||
p: Predicate URI (optional, use None for wildcard)
|
||||
o: Object URI or Literal (optional, use None for wildcard)
|
||||
s: Subject filter - URI string, Term dict, or None for wildcard
|
||||
p: Predicate filter - URI string, Term dict, or None for wildcard
|
||||
o: Object filter - URI/literal string, Term dict, or None for wildcard
|
||||
g: Named graph filter - URI string or None for all graphs
|
||||
user: User/keyspace identifier (optional)
|
||||
collection: Collection identifier (optional)
|
||||
limit: Maximum results to return (default: 100)
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
Returns:
|
||||
dict: Query results with matching triples
|
||||
List[Dict]: List of matching triples in wire format
|
||||
|
||||
Example:
|
||||
```python
|
||||
|
|
@ -860,33 +1176,54 @@ class SocketFlowInstance:
|
|||
flow = socket.flow("default")
|
||||
|
||||
# Find all triples about a specific subject
|
||||
result = flow.triples_query(
|
||||
triples = flow.triples_query(
|
||||
s="http://example.org/person/marie-curie",
|
||||
user="trustgraph",
|
||||
collection="scientists"
|
||||
)
|
||||
|
||||
# Query with named graph filter
|
||||
triples = flow.triples_query(
|
||||
s="urn:trustgraph:session:abc123",
|
||||
g="urn:graph:retrieval",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
```
|
||||
"""
|
||||
request = {"limit": limit}
|
||||
if s is not None:
|
||||
request["s"] = str(s)
|
||||
if p is not None:
|
||||
request["p"] = str(p)
|
||||
if o is not None:
|
||||
request["o"] = str(o)
|
||||
|
||||
# Build Term dicts for s/p/o (auto-converts strings)
|
||||
s_term = build_term(s)
|
||||
p_term = build_term(p)
|
||||
o_term = build_term(o)
|
||||
|
||||
if s_term is not None:
|
||||
request["s"] = s_term
|
||||
if p_term is not None:
|
||||
request["p"] = p_term
|
||||
if o_term is not None:
|
||||
request["o"] = o_term
|
||||
if g is not None:
|
||||
request["g"] = g
|
||||
if user is not None:
|
||||
request["user"] = user
|
||||
if collection is not None:
|
||||
request["collection"] = collection
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("triples", self.flow_id, request, False)
|
||||
result = self.client._send_request_sync("triples", self.flow_id, request, False)
|
||||
# Return the triples list from the response
|
||||
if isinstance(result, dict) and "response" in result:
|
||||
return result["response"]
|
||||
return result
|
||||
|
||||
def triples_query_stream(
|
||||
self,
|
||||
s: Optional[str] = None,
|
||||
p: Optional[str] = None,
|
||||
o: Optional[str] = None,
|
||||
s: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
p: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
o: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
g: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
collection: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
|
|
@ -900,9 +1237,10 @@ class SocketFlowInstance:
|
|||
and memory overhead for large result sets.
|
||||
|
||||
Args:
|
||||
s: Subject URI (optional, use None for wildcard)
|
||||
p: Predicate URI (optional, use None for wildcard)
|
||||
o: Object URI or Literal (optional, use None for wildcard)
|
||||
s: Subject filter - URI string, Term dict, or None for wildcard
|
||||
p: Predicate filter - URI string, Term dict, or None for wildcard
|
||||
o: Object filter - URI/literal string, Term dict, or None for wildcard
|
||||
g: Named graph filter - URI string or None for all graphs
|
||||
user: User/keyspace identifier (optional)
|
||||
collection: Collection identifier (optional)
|
||||
limit: Maximum results to return (default: 100)
|
||||
|
|
@ -930,12 +1268,20 @@ class SocketFlowInstance:
|
|||
"streaming": True,
|
||||
"batch-size": batch_size,
|
||||
}
|
||||
if s is not None:
|
||||
request["s"] = str(s)
|
||||
if p is not None:
|
||||
request["p"] = str(p)
|
||||
if o is not None:
|
||||
request["o"] = str(o)
|
||||
|
||||
# Build Term dicts for s/p/o (auto-converts strings)
|
||||
s_term = build_term(s)
|
||||
p_term = build_term(p)
|
||||
o_term = build_term(o)
|
||||
|
||||
if s_term is not None:
|
||||
request["s"] = s_term
|
||||
if p_term is not None:
|
||||
request["p"] = p_term
|
||||
if o_term is not None:
|
||||
request["o"] = o_term
|
||||
if g is not None:
|
||||
request["g"] = g
|
||||
if user is not None:
|
||||
request["user"] = user
|
||||
if collection is not None:
|
||||
|
|
|
|||
|
|
@ -212,19 +212,21 @@ class ProvenanceEvent:
|
|||
Each event represents a provenance node created during query processing.
|
||||
|
||||
Attributes:
|
||||
provenance_id: URI of the provenance node (e.g., urn:trustgraph:session:abc123)
|
||||
event_type: Type of provenance event (session, retrieval, selection, answer)
|
||||
explain_id: URI of the provenance node (e.g., urn:trustgraph:question:abc123)
|
||||
explain_graph: Named graph where provenance triples are stored (e.g., urn:graph:retrieval)
|
||||
event_type: Type of provenance event (question, exploration, focus, synthesis)
|
||||
"""
|
||||
provenance_id: str
|
||||
event_type: str = "" # Derived from provenance_id (session, retrieval, selection, answer)
|
||||
explain_id: str
|
||||
explain_graph: str = ""
|
||||
event_type: str = "" # Derived from explain_id
|
||||
|
||||
def __post_init__(self):
|
||||
# Extract event type from provenance_id
|
||||
if "session" in self.provenance_id:
|
||||
self.event_type = "session"
|
||||
elif "retrieval" in self.provenance_id:
|
||||
self.event_type = "retrieval"
|
||||
elif "selection" in self.provenance_id:
|
||||
self.event_type = "selection"
|
||||
elif "answer" in self.provenance_id:
|
||||
self.event_type = "answer"
|
||||
# Extract event type from explain_id
|
||||
if "question" in self.explain_id:
|
||||
self.event_type = "question"
|
||||
elif "exploration" in self.explain_id:
|
||||
self.event_type = "exploration"
|
||||
elif "focus" in self.explain_id:
|
||||
self.event_type = "focus"
|
||||
elif "synthesis" in self.explain_id:
|
||||
self.event_type = "synthesis"
|
||||
|
|
|
|||
|
|
@ -59,6 +59,15 @@ class AgentResponseTranslator(MessageTranslator):
|
|||
result["end_of_message"] = getattr(obj, "end_of_message", False)
|
||||
result["end_of_dialog"] = getattr(obj, "end_of_dialog", False)
|
||||
|
||||
# Include explainability fields if present
|
||||
explain_id = getattr(obj, "explain_id", None)
|
||||
if explain_id:
|
||||
result["explain_id"] = explain_id
|
||||
|
||||
explain_graph = getattr(obj, "explain_graph", None)
|
||||
if explain_graph is not None:
|
||||
result["explain_graph"] = explain_graph
|
||||
|
||||
# Always include error if present
|
||||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "code": obj.error.code}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,12 @@ class DocumentRagResponseTranslator(MessageTranslator):
|
|||
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
# Include response content (even if empty string)
|
||||
# Include message_type for distinguishing chunk vs explain messages
|
||||
message_type = getattr(obj, "message_type", "")
|
||||
if message_type:
|
||||
result["message_type"] = message_type
|
||||
|
||||
# Include response content for chunk messages
|
||||
if obj.response is not None:
|
||||
result["response"] = obj.response
|
||||
|
||||
|
|
@ -48,9 +53,12 @@ class DocumentRagResponseTranslator(MessageTranslator):
|
|||
if explain_graph is not None:
|
||||
result["explain_graph"] = explain_graph
|
||||
|
||||
# Include end_of_stream flag
|
||||
# Include end_of_stream flag (LLM stream complete)
|
||||
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||
|
||||
# Include end_of_session flag (entire session complete)
|
||||
result["end_of_session"] = getattr(obj, "end_of_session", False)
|
||||
|
||||
# Always include error if present
|
||||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "type": obj.error.type}
|
||||
|
|
@ -59,7 +67,8 @@ class DocumentRagResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
is_final = getattr(obj, 'end_of_stream', False)
|
||||
# Session is complete when end_of_session is True
|
||||
is_final = getattr(obj, 'end_of_session', False)
|
||||
return self.from_pulsar(obj), is_final
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -82,6 +82,10 @@ from . namespaces import (
|
|||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
|
||||
# Agent provenance predicates
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
|
||||
# Agent document references
|
||||
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
|
||||
# Document reference predicate
|
||||
TG_DOCUMENT,
|
||||
# Named graphs
|
||||
GRAPH_DEFAULT, GRAPH_SOURCE, GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
|
@ -165,6 +169,10 @@ __all__ = [
|
|||
"TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION",
|
||||
# Agent provenance predicates
|
||||
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_ANSWER",
|
||||
# Agent document references
|
||||
"TG_THOUGHT_DOCUMENT", "TG_OBSERVATION_DOCUMENT",
|
||||
# Document reference predicate
|
||||
"TG_DOCUMENT",
|
||||
# Named graphs
|
||||
"GRAPH_DEFAULT", "GRAPH_SOURCE", "GRAPH_RETRIEVAL",
|
||||
# Triple builders
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ from . namespaces import (
|
|||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
|
||||
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION,
|
||||
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
|
||||
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
|
||||
TG_AGENT_QUESTION,
|
||||
)
|
||||
|
||||
|
|
@ -73,10 +74,12 @@ def agent_session_triples(
|
|||
def agent_iteration_triples(
|
||||
iteration_uri: str,
|
||||
parent_uri: str,
|
||||
thought: str,
|
||||
action: str,
|
||||
arguments: Dict[str, Any],
|
||||
observation: str,
|
||||
thought: str = "",
|
||||
action: str = "",
|
||||
arguments: Dict[str, Any] = None,
|
||||
observation: str = "",
|
||||
thought_document_id: Optional[str] = None,
|
||||
observation_document_id: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for one agent iteration (Analysis - think/act/observe cycle).
|
||||
|
|
@ -85,36 +88,53 @@ def agent_iteration_triples(
|
|||
- Entity declaration with tg:Analysis type
|
||||
- wasDerivedFrom link to parent (previous iteration or session)
|
||||
- Thought, action, arguments, and observation data
|
||||
- Document references for thought/observation when stored in librarian
|
||||
|
||||
Args:
|
||||
iteration_uri: URI of this iteration (from agent_iteration_uri)
|
||||
parent_uri: URI of the parent (previous iteration or session)
|
||||
thought: The agent's reasoning/thought
|
||||
thought: The agent's reasoning/thought (used if thought_document_id not provided)
|
||||
action: The tool/action name
|
||||
arguments: Arguments passed to the tool (will be JSON-encoded)
|
||||
observation: The result/observation from the tool
|
||||
observation: The result/observation from the tool (used if observation_document_id not provided)
|
||||
thought_document_id: Optional document URI for thought in librarian (preferred)
|
||||
observation_document_id: Optional document URI for observation in librarian (preferred)
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
"""
|
||||
if arguments is None:
|
||||
arguments = {}
|
||||
|
||||
triples = [
|
||||
_triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)),
|
||||
_triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")),
|
||||
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
|
||||
_triple(iteration_uri, TG_THOUGHT, _literal(thought)),
|
||||
_triple(iteration_uri, TG_ACTION, _literal(action)),
|
||||
_triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))),
|
||||
_triple(iteration_uri, TG_OBSERVATION, _literal(observation)),
|
||||
]
|
||||
|
||||
# Thought: use document reference or inline
|
||||
if thought_document_id:
|
||||
triples.append(_triple(iteration_uri, TG_THOUGHT_DOCUMENT, _iri(thought_document_id)))
|
||||
elif thought:
|
||||
triples.append(_triple(iteration_uri, TG_THOUGHT, _literal(thought)))
|
||||
|
||||
# Observation: use document reference or inline
|
||||
if observation_document_id:
|
||||
triples.append(_triple(iteration_uri, TG_OBSERVATION_DOCUMENT, _iri(observation_document_id)))
|
||||
elif observation:
|
||||
triples.append(_triple(iteration_uri, TG_OBSERVATION, _literal(observation)))
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
def agent_final_triples(
|
||||
final_uri: str,
|
||||
parent_uri: str,
|
||||
answer: str,
|
||||
answer: str = "",
|
||||
document_id: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for an agent final answer (Conclusion).
|
||||
|
|
@ -122,20 +142,29 @@ def agent_final_triples(
|
|||
Creates:
|
||||
- Entity declaration with tg:Conclusion type
|
||||
- wasDerivedFrom link to parent (last iteration or session)
|
||||
- The answer text
|
||||
- Either document reference (if document_id provided) or inline answer
|
||||
|
||||
Args:
|
||||
final_uri: URI of the final answer (from agent_final_uri)
|
||||
parent_uri: URI of the parent (last iteration or session if no iterations)
|
||||
answer: The final answer text
|
||||
answer: The final answer text (used if document_id not provided)
|
||||
document_id: Optional document URI in librarian (preferred)
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
"""
|
||||
return [
|
||||
triples = [
|
||||
_triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)),
|
||||
_triple(final_uri, RDFS_LABEL, _literal("Conclusion")),
|
||||
_triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
|
||||
_triple(final_uri, TG_ANSWER, _literal(answer)),
|
||||
]
|
||||
|
||||
if document_id:
|
||||
# Store reference to document in librarian (as IRI)
|
||||
triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id)))
|
||||
elif answer:
|
||||
# Fallback: store inline answer
|
||||
triples.append(_triple(final_uri, TG_ANSWER, _literal(answer)))
|
||||
|
||||
return triples
|
||||
|
|
|
|||
|
|
@ -92,6 +92,10 @@ TG_ARGUMENTS = TG + "arguments"
|
|||
TG_OBSERVATION = TG + "observation"
|
||||
TG_ANSWER = TG + "answer"
|
||||
|
||||
# Agent document references (for librarian storage)
|
||||
TG_THOUGHT_DOCUMENT = TG + "thoughtDocument"
|
||||
TG_OBSERVATION_DOCUMENT = TG + "observationDocument"
|
||||
|
||||
# Named graph URIs for RDF datasets
|
||||
# These separate different types of data while keeping them in the same collection
|
||||
GRAPH_DEFAULT = "" # Core knowledge facts (triples extracted from documents)
|
||||
|
|
|
|||
|
|
@ -30,11 +30,15 @@ class AgentRequest:
|
|||
@dataclass
|
||||
class AgentResponse:
|
||||
# Streaming-first design
|
||||
chunk_type: str = "" # "thought", "action", "observation", "answer", "error"
|
||||
chunk_type: str = "" # "thought", "action", "observation", "answer", "explain", "error"
|
||||
content: str = "" # The actual content (interpretation depends on chunk_type)
|
||||
end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete
|
||||
end_of_dialog: bool = False # Entire agent dialog is complete
|
||||
|
||||
# Explainability fields
|
||||
explain_id: str | None = None # Provenance URI (announced as created)
|
||||
explain_graph: str | None = None # Named graph where explain was stored
|
||||
|
||||
# Legacy fields (deprecated but kept for backward compatibility)
|
||||
answer: str = ""
|
||||
error: Error | None = None
|
||||
|
|
|
|||
|
|
@ -43,6 +43,8 @@ class DocumentRagQuery:
|
|||
class DocumentRagResponse:
|
||||
error: Error | None = None
|
||||
response: str | None = ""
|
||||
end_of_stream: bool = False
|
||||
end_of_stream: bool = False # LLM response stream complete
|
||||
explain_id: str | None = None # Single explain URI (announced as created)
|
||||
explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval)
|
||||
message_type: str = "" # "chunk" or "explain"
|
||||
end_of_session: bool = False # Entire session complete
|
||||
|
|
|
|||
|
|
@ -4,8 +4,19 @@ Uses the agent service to answer a question
|
|||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import textwrap
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api import (
|
||||
Api,
|
||||
ExplainabilityClient,
|
||||
ProvenanceEvent,
|
||||
Question,
|
||||
Analysis,
|
||||
Conclusion,
|
||||
AgentThought,
|
||||
AgentObservation,
|
||||
AgentAnswer,
|
||||
)
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
|
@ -97,11 +108,148 @@ def output(text, prefix="> ", width=78):
|
|||
)
|
||||
print(out)
|
||||
|
||||
def question_explainable(
|
||||
url, question_text, flow_id, user, collection,
|
||||
state=None, group=None, verbose=False, token=None, debug=False
|
||||
):
|
||||
"""Execute agent with explainability - shows provenance events inline."""
|
||||
api = Api(url=url, token=token)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(flow_id)
|
||||
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
|
||||
|
||||
try:
|
||||
# Track last chunk type for formatting
|
||||
last_chunk_type = None
|
||||
current_outputter = None
|
||||
|
||||
# Stream agent with explainability - process events as they arrive
|
||||
for item in flow.agent_explain(
|
||||
question=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
state=state,
|
||||
group=group,
|
||||
):
|
||||
if isinstance(item, AgentThought):
|
||||
if last_chunk_type != "thought":
|
||||
if current_outputter:
|
||||
current_outputter.__exit__(None, None, None)
|
||||
current_outputter = None
|
||||
print() # Blank line between message types
|
||||
if verbose:
|
||||
current_outputter = Outputter(width=78, prefix="\U0001f914 ")
|
||||
current_outputter.__enter__()
|
||||
last_chunk_type = "thought"
|
||||
if current_outputter:
|
||||
current_outputter.output(item.content)
|
||||
if current_outputter.word_buffer:
|
||||
print(current_outputter.word_buffer, end="", flush=True)
|
||||
current_outputter.column += len(current_outputter.word_buffer)
|
||||
current_outputter.word_buffer = ""
|
||||
|
||||
elif isinstance(item, AgentObservation):
|
||||
if last_chunk_type != "observation":
|
||||
if current_outputter:
|
||||
current_outputter.__exit__(None, None, None)
|
||||
current_outputter = None
|
||||
print()
|
||||
if verbose:
|
||||
current_outputter = Outputter(width=78, prefix="\U0001f4a1 ")
|
||||
current_outputter.__enter__()
|
||||
last_chunk_type = "observation"
|
||||
if current_outputter:
|
||||
current_outputter.output(item.content)
|
||||
if current_outputter.word_buffer:
|
||||
print(current_outputter.word_buffer, end="", flush=True)
|
||||
current_outputter.column += len(current_outputter.word_buffer)
|
||||
current_outputter.word_buffer = ""
|
||||
|
||||
elif isinstance(item, AgentAnswer):
|
||||
if last_chunk_type != "answer":
|
||||
if current_outputter:
|
||||
current_outputter.__exit__(None, None, None)
|
||||
current_outputter = None
|
||||
print()
|
||||
last_chunk_type = "answer"
|
||||
# Print answer content directly
|
||||
print(item.content, end="", flush=True)
|
||||
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
# Process provenance event immediately
|
||||
prov_id = item.explain_id
|
||||
explain_graph = item.explain_graph or "urn:graph:retrieval"
|
||||
|
||||
entity = explain_client.fetch_entity(
|
||||
prov_id,
|
||||
graph=explain_graph,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
if entity is None:
|
||||
if debug:
|
||||
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
# Display based on entity type
|
||||
if isinstance(entity, Question):
|
||||
print(f"\n [session] {prov_id}", file=sys.stderr)
|
||||
if entity.query:
|
||||
print(f" Query: {entity.query}", file=sys.stderr)
|
||||
if entity.timestamp:
|
||||
print(f" Time: {entity.timestamp}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Analysis):
|
||||
print(f"\n [iteration] {prov_id}", file=sys.stderr)
|
||||
if entity.thought:
|
||||
thought_short = entity.thought[:80] + "..." if len(entity.thought) > 80 else entity.thought
|
||||
print(f" Thought: {thought_short}", file=sys.stderr)
|
||||
if entity.action:
|
||||
print(f" Action: {entity.action}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Conclusion):
|
||||
print(f"\n [conclusion] {prov_id}", file=sys.stderr)
|
||||
if entity.answer:
|
||||
print(f" Answer length: {len(entity.answer)} chars", file=sys.stderr)
|
||||
|
||||
else:
|
||||
if debug:
|
||||
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
|
||||
|
||||
# Close any remaining outputter
|
||||
if current_outputter:
|
||||
current_outputter.__exit__(None, None, None)
|
||||
current_outputter = None
|
||||
|
||||
# Final newline if we ended with answer
|
||||
if last_chunk_type == "answer":
|
||||
print()
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
|
||||
def question(
|
||||
url, question, flow_id, user, collection,
|
||||
plan=None, state=None, group=None, verbose=False, streaming=True,
|
||||
token=None
|
||||
token=None, explainable=False, debug=False
|
||||
):
|
||||
# Explainable mode uses the API to capture and process provenance events
|
||||
if explainable:
|
||||
question_explainable(
|
||||
url=url,
|
||||
question_text=question,
|
||||
flow_id=flow_id,
|
||||
user=user,
|
||||
collection=collection,
|
||||
state=state,
|
||||
group=group,
|
||||
verbose=verbose,
|
||||
token=token,
|
||||
debug=debug
|
||||
)
|
||||
return
|
||||
|
||||
if verbose:
|
||||
output(wrap(question), "\U00002753 ")
|
||||
|
|
@ -270,6 +418,18 @@ def main():
|
|||
help=f'Disable streaming (use legacy mode)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--explainable',
|
||||
action='store_true',
|
||||
help='Show provenance events: Session, Iterations, Conclusion (implies streaming)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--debug',
|
||||
action='store_true',
|
||||
help='Show debug output for troubleshooting'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -286,6 +446,8 @@ def main():
|
|||
verbose = args.verbose,
|
||||
streaming = not args.no_streaming,
|
||||
token = args.token,
|
||||
explainable = args.explainable,
|
||||
debug = args.debug,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,16 @@ Uses the DocumentRAG service to answer a question
|
|||
|
||||
import argparse
|
||||
import os
|
||||
from trustgraph.api import Api
|
||||
import sys
|
||||
from trustgraph.api import (
|
||||
Api,
|
||||
ExplainabilityClient,
|
||||
RAGChunk,
|
||||
ProvenanceEvent,
|
||||
Question,
|
||||
Exploration,
|
||||
Synthesis,
|
||||
)
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
|
@ -12,7 +21,90 @@ default_user = 'trustgraph'
|
|||
default_collection = 'default'
|
||||
default_doc_limit = 10
|
||||
|
||||
def question(url, flow_id, question, user, collection, doc_limit, streaming=True, token=None):
|
||||
|
||||
def question_explainable(
|
||||
url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False
|
||||
):
|
||||
"""Execute document RAG with explainability - shows provenance events inline."""
|
||||
api = Api(url=url, token=token)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(flow_id)
|
||||
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
|
||||
|
||||
try:
|
||||
# Stream DocumentRAG with explainability - process events as they arrive
|
||||
for item in flow.document_rag_explain(
|
||||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
# Print response content
|
||||
print(item.content, end="", flush=True)
|
||||
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
# Process provenance event immediately
|
||||
prov_id = item.explain_id
|
||||
explain_graph = item.explain_graph or "urn:graph:retrieval"
|
||||
|
||||
entity = explain_client.fetch_entity(
|
||||
prov_id,
|
||||
graph=explain_graph,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
if entity is None:
|
||||
if debug:
|
||||
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
# Display based on entity type
|
||||
if isinstance(entity, Question):
|
||||
print(f"\n [question] {prov_id}", file=sys.stderr)
|
||||
if entity.query:
|
||||
print(f" Query: {entity.query}", file=sys.stderr)
|
||||
if entity.timestamp:
|
||||
print(f" Time: {entity.timestamp}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Exploration):
|
||||
print(f"\n [exploration] {prov_id}", file=sys.stderr)
|
||||
if entity.chunk_count:
|
||||
print(f" Chunks retrieved: {entity.chunk_count}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Synthesis):
|
||||
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
|
||||
if entity.content:
|
||||
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
|
||||
|
||||
else:
|
||||
if debug:
|
||||
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
|
||||
|
||||
print() # Final newline
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
|
||||
def question(
|
||||
url, flow_id, question_text, user, collection, doc_limit,
|
||||
streaming=True, token=None, explainable=False, debug=False
|
||||
):
|
||||
# Explainable mode uses the API to capture and process provenance events
|
||||
if explainable:
|
||||
question_explainable(
|
||||
url=url,
|
||||
flow_id=flow_id,
|
||||
question_text=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
token=token,
|
||||
debug=debug
|
||||
)
|
||||
return
|
||||
|
||||
# Create API client
|
||||
api = Api(url=url, token=token)
|
||||
|
|
@ -24,7 +116,7 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True
|
|||
|
||||
try:
|
||||
response = flow.document_rag(
|
||||
query=question,
|
||||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
|
|
@ -42,13 +134,14 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True
|
|||
# Use REST API for non-streaming
|
||||
flow = api.flow().id(flow_id)
|
||||
resp = flow.document_rag(
|
||||
query=question,
|
||||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
)
|
||||
print(resp)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -105,6 +198,18 @@ def main():
|
|||
help='Disable streaming (use non-streaming mode)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--explainable',
|
||||
action='store_true',
|
||||
help='Show provenance events: Question, Exploration, Synthesis (implies streaming)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--debug',
|
||||
action='store_true',
|
||||
help='Show debug output for troubleshooting'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -112,12 +217,14 @@ def main():
|
|||
question(
|
||||
url=args.url,
|
||||
flow_id=args.flow_id,
|
||||
question=args.question,
|
||||
question_text=args.question,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
doc_limit=args.doc_limit,
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,16 @@ import os
|
|||
import sys
|
||||
import websockets
|
||||
import asyncio
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api import (
|
||||
Api,
|
||||
ExplainabilityClient,
|
||||
RAGChunk,
|
||||
ProvenanceEvent,
|
||||
Question,
|
||||
Exploration,
|
||||
Focus,
|
||||
Synthesis,
|
||||
)
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
|
@ -602,18 +611,111 @@ async def _question_explainable(
|
|||
print() # Final newline
|
||||
|
||||
|
||||
def _question_explainable_api(
|
||||
url, flow_id, question_text, user, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, token=None, debug=False
|
||||
):
|
||||
"""Execute graph RAG with explainability using the new API classes."""
|
||||
api = Api(url=url, token=token)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(flow_id)
|
||||
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
|
||||
|
||||
try:
|
||||
# Stream GraphRAG with explainability - process events as they arrive
|
||||
for item in flow.graph_rag_explain(
|
||||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_subgraph_count=5,
|
||||
max_entity_distance=max_path_length,
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
# Print response content
|
||||
print(item.content, end="", flush=True)
|
||||
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
# Process provenance event immediately
|
||||
prov_id = item.explain_id
|
||||
explain_graph = item.explain_graph or "urn:graph:retrieval"
|
||||
|
||||
entity = explain_client.fetch_entity(
|
||||
prov_id,
|
||||
graph=explain_graph,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
if entity is None:
|
||||
if debug:
|
||||
print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
# Display based on entity type
|
||||
if isinstance(entity, Question):
|
||||
print(f"\n [question] {prov_id}", file=sys.stderr)
|
||||
if entity.query:
|
||||
print(f" Query: {entity.query}", file=sys.stderr)
|
||||
if entity.timestamp:
|
||||
print(f" Time: {entity.timestamp}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Exploration):
|
||||
print(f"\n [exploration] {prov_id}", file=sys.stderr)
|
||||
if entity.edge_count:
|
||||
print(f" Edges explored: {entity.edge_count}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Focus):
|
||||
print(f"\n [focus] {prov_id}", file=sys.stderr)
|
||||
if entity.selected_edge_uris:
|
||||
print(f" Focused on {len(entity.selected_edge_uris)} edge(s)", file=sys.stderr)
|
||||
|
||||
# Fetch full focus with edge details
|
||||
focus_full = explain_client.fetch_focus_with_edges(
|
||||
prov_id,
|
||||
graph=explain_graph,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
if focus_full and focus_full.edge_selections:
|
||||
for edge_sel in focus_full.edge_selections:
|
||||
if edge_sel.edge:
|
||||
# Resolve labels for edge components
|
||||
s_label, p_label, o_label = explain_client.resolve_edge_labels(
|
||||
edge_sel.edge, user, collection
|
||||
)
|
||||
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
|
||||
if edge_sel.reasoning:
|
||||
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
|
||||
print(f" Reason: {r_short}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Synthesis):
|
||||
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
|
||||
if entity.content:
|
||||
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
|
||||
|
||||
else:
|
||||
if debug:
|
||||
print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr)
|
||||
|
||||
print() # Final newline
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
|
||||
def question(
|
||||
url, flow_id, question, user, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, streaming=True, token=None,
|
||||
explainable=False, debug=False
|
||||
):
|
||||
|
||||
# Explainable mode uses direct websocket to capture provenance events
|
||||
# Explainable mode uses the API to capture and process provenance events
|
||||
if explainable:
|
||||
asyncio.run(_question_explainable(
|
||||
_question_explainable_api(
|
||||
url=url,
|
||||
flow_id=flow_id,
|
||||
question=question,
|
||||
question_text=question,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
|
|
@ -622,7 +724,7 @@ def question(
|
|||
max_path_length=max_path_length,
|
||||
token=token,
|
||||
debug=debug
|
||||
))
|
||||
)
|
||||
return
|
||||
|
||||
# Create API client
|
||||
|
|
|
|||
|
|
@ -14,180 +14,17 @@ import json
|
|||
import os
|
||||
import sys
|
||||
from tabulate import tabulate
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api import Api, ExplainabilityClient
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
# Predicates
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
TG_QUERY = TG + "query"
|
||||
TG_QUESTION = TG + "Question"
|
||||
TG_ANALYSIS = TG + "Analysis"
|
||||
TG_EXPLORATION = TG + "Exploration"
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
|
||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
||||
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
|
||||
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
|
||||
# Retrieval graph
|
||||
RETRIEVAL_GRAPH = "urn:graph:retrieval"
|
||||
|
||||
|
||||
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
|
||||
"""Query triples using the socket API."""
|
||||
request = {
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit,
|
||||
"streaming": False,
|
||||
}
|
||||
|
||||
if s is not None:
|
||||
request["s"] = {"t": "i", "i": s}
|
||||
if p is not None:
|
||||
request["p"] = {"t": "i", "i": p}
|
||||
if o is not None:
|
||||
if isinstance(o, str):
|
||||
if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"):
|
||||
request["o"] = {"t": "i", "i": o}
|
||||
else:
|
||||
request["o"] = {"t": "l", "v": o}
|
||||
elif isinstance(o, dict):
|
||||
request["o"] = o
|
||||
if g is not None:
|
||||
request["g"] = g
|
||||
|
||||
triples = []
|
||||
try:
|
||||
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
|
||||
if isinstance(response, dict):
|
||||
triple_list = response.get("response", response.get("triples", []))
|
||||
else:
|
||||
triple_list = response
|
||||
|
||||
if not isinstance(triple_list, list):
|
||||
triple_list = [triple_list] if triple_list else []
|
||||
|
||||
for t in triple_list:
|
||||
s_val = extract_value(t.get("s", {}))
|
||||
p_val = extract_value(t.get("p", {}))
|
||||
o_val = extract_value(t.get("o", {}))
|
||||
triples.append((s_val, p_val, o_val))
|
||||
except Exception as e:
|
||||
print(f"Error querying triples: {e}", file=sys.stderr)
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
def extract_value(term):
|
||||
"""Extract value from a term dict."""
|
||||
if not term:
|
||||
return ""
|
||||
|
||||
t = term.get("t") or term.get("type")
|
||||
|
||||
if t == "i":
|
||||
return term.get("i") or term.get("iri", "")
|
||||
elif t == "l":
|
||||
return term.get("v") or term.get("value", "")
|
||||
elif t == "t":
|
||||
# Quoted triple
|
||||
tr = term.get("tr") or term.get("triple", {})
|
||||
return {
|
||||
"s": extract_value(tr.get("s", {})),
|
||||
"p": extract_value(tr.get("p", {})),
|
||||
"o": extract_value(tr.get("o", {})),
|
||||
}
|
||||
|
||||
# Fallback for raw values
|
||||
if "i" in term:
|
||||
return term["i"]
|
||||
if "v" in term:
|
||||
return term["v"]
|
||||
|
||||
return str(term)
|
||||
|
||||
|
||||
def get_timestamp(socket, flow_id, user, collection, question_id):
|
||||
"""Get timestamp for a question."""
|
||||
triples = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
s=question_id, p=PROV_STARTED_AT_TIME, g=RETRIEVAL_GRAPH
|
||||
)
|
||||
for s, p, o in triples:
|
||||
return o
|
||||
return ""
|
||||
|
||||
|
||||
def get_session_type(socket, flow_id, user, collection, session_id):
|
||||
"""
|
||||
Get the type of session (Agent or GraphRAG).
|
||||
|
||||
Both have tg:Question type, so we distinguish by URI pattern
|
||||
or by checking what's derived from it.
|
||||
"""
|
||||
# Fast path: check URI pattern
|
||||
if session_id.startswith("urn:trustgraph:agent:"):
|
||||
return "Agent"
|
||||
if session_id.startswith("urn:trustgraph:question:"):
|
||||
return "GraphRAG"
|
||||
|
||||
# Check what's derived from this entity
|
||||
derived = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
p=PROV_WAS_DERIVED_FROM, o=session_id, g=RETRIEVAL_GRAPH
|
||||
)
|
||||
generated = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
p=PROV_WAS_GENERATED_BY, o=session_id, g=RETRIEVAL_GRAPH
|
||||
)
|
||||
|
||||
for s, p, o in derived + generated:
|
||||
child_types = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
s=s, p=RDF_TYPE, g=RETRIEVAL_GRAPH
|
||||
)
|
||||
for _, _, child_type in child_types:
|
||||
if child_type == TG_ANALYSIS:
|
||||
return "Agent"
|
||||
if child_type == TG_EXPLORATION:
|
||||
return "GraphRAG"
|
||||
|
||||
return "GraphRAG"
|
||||
|
||||
|
||||
def list_sessions(socket, flow_id, user, collection, limit):
|
||||
"""List all explainability sessions (GraphRAG and Agent) by finding questions."""
|
||||
# Query for all triples with predicate = tg:query
|
||||
triples = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
p=TG_QUERY, g=RETRIEVAL_GRAPH, limit=limit
|
||||
)
|
||||
|
||||
sessions = []
|
||||
for question_id, _, query_text in triples:
|
||||
# Get timestamp if available
|
||||
timestamp = get_timestamp(socket, flow_id, user, collection, question_id)
|
||||
# Get session type (Agent or GraphRAG)
|
||||
session_type = get_session_type(socket, flow_id, user, collection, question_id)
|
||||
|
||||
sessions.append({
|
||||
"id": question_id,
|
||||
"type": session_type,
|
||||
"question": query_text,
|
||||
"time": timestamp,
|
||||
})
|
||||
|
||||
# Sort by timestamp (newest first) if available
|
||||
sessions.sort(key=lambda x: x.get("time", ""), reverse=True)
|
||||
|
||||
return sessions
|
||||
|
||||
|
||||
def truncate_text(text, max_len=60):
|
||||
"""Truncate text to max length with ellipsis."""
|
||||
if not text:
|
||||
|
|
@ -277,16 +114,42 @@ def main():
|
|||
try:
|
||||
api = Api(args.api_url, token=args.token)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(args.flow_id)
|
||||
explain_client = ExplainabilityClient(flow)
|
||||
|
||||
try:
|
||||
sessions = list_sessions(
|
||||
socket=socket,
|
||||
flow_id=args.flow_id,
|
||||
# List all sessions using the API
|
||||
questions = explain_client.list_sessions(
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
# Convert to output format
|
||||
sessions = []
|
||||
for q in questions:
|
||||
session_type = explain_client.detect_session_type(
|
||||
q.uri,
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection
|
||||
)
|
||||
|
||||
# Map type names
|
||||
type_display = {
|
||||
"graphrag": "GraphRAG",
|
||||
"docrag": "DocRAG",
|
||||
"agent": "Agent",
|
||||
}.get(session_type, session_type.title())
|
||||
|
||||
sessions.append({
|
||||
"id": q.uri,
|
||||
"type": type_display,
|
||||
"question": q.query,
|
||||
"time": q.timestamp,
|
||||
})
|
||||
|
||||
if args.format == 'json':
|
||||
print_json(sessions)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -291,42 +291,25 @@ def query_graph(
|
|||
):
|
||||
"""Query the triple store with pattern matching.
|
||||
|
||||
Uses the WebSocket API's raw streaming mode for efficient delivery of results.
|
||||
Uses the API's triples_query_stream for efficient streaming delivery.
|
||||
"""
|
||||
socket = Api(url, token=token).socket()
|
||||
|
||||
# Build request dict directly (bypassing triples_query_stream's string conversion)
|
||||
request = {
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit,
|
||||
"streaming": True,
|
||||
"batch-size": batch_size,
|
||||
}
|
||||
|
||||
# Add term dicts for s/p/o (None means wildcard)
|
||||
if subject is not None:
|
||||
request["s"] = subject
|
||||
if predicate is not None:
|
||||
request["p"] = predicate
|
||||
if obj is not None:
|
||||
request["o"] = obj
|
||||
if graph is not None:
|
||||
request["g"] = graph
|
||||
flow = socket.flow(flow_id)
|
||||
|
||||
all_triples = []
|
||||
|
||||
try:
|
||||
# Use raw streaming mode - yields response dicts directly
|
||||
for response in socket._send_request_sync(
|
||||
"triples", flow_id, request, streaming_raw=True
|
||||
# Use triples_query_stream - accepts Term dicts directly
|
||||
for triples in flow.triples_query_stream(
|
||||
s=subject,
|
||||
p=predicate,
|
||||
o=obj,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=limit,
|
||||
batch_size=batch_size,
|
||||
):
|
||||
# Response may have triples in different locations depending on format
|
||||
if isinstance(response, dict):
|
||||
triples = response.get("response", response.get("triples", []))
|
||||
else:
|
||||
triples = response
|
||||
|
||||
if not isinstance(triples, list):
|
||||
triples = [triples] if triples else []
|
||||
|
||||
|
|
|
|||
|
|
@ -18,228 +18,99 @@ import argparse
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api import (
|
||||
Api,
|
||||
ExplainabilityClient,
|
||||
Question,
|
||||
Exploration,
|
||||
Focus,
|
||||
Synthesis,
|
||||
Analysis,
|
||||
Conclusion,
|
||||
)
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
# Predicates
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
TG_QUERY = TG + "query"
|
||||
TG_EDGE_COUNT = TG + "edgeCount"
|
||||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_CONTENT = TG + "content"
|
||||
TG_DOCUMENT = TG + "document"
|
||||
TG_REIFIES = TG + "reifies"
|
||||
# Explainability entity types
|
||||
TG_QUESTION = TG + "Question"
|
||||
TG_EXPLORATION = TG + "Exploration"
|
||||
TG_FOCUS = TG + "Focus"
|
||||
TG_SYNTHESIS = TG + "Synthesis"
|
||||
TG_ANALYSIS = TG + "Analysis"
|
||||
TG_CONCLUSION = TG + "Conclusion"
|
||||
|
||||
# Agent predicates
|
||||
TG_THOUGHT = TG + "thought"
|
||||
TG_ACTION = TG + "action"
|
||||
TG_ARGUMENTS = TG + "arguments"
|
||||
TG_OBSERVATION = TG + "observation"
|
||||
TG_ANSWER = TG + "answer"
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
|
||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
||||
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
|
||||
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
# Graphs
|
||||
RETRIEVAL_GRAPH = "urn:graph:retrieval"
|
||||
SOURCE_GRAPH = "urn:graph:source"
|
||||
|
||||
|
||||
def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000):
|
||||
"""Query triples using the socket API."""
|
||||
request = {
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit,
|
||||
"streaming": False,
|
||||
}
|
||||
|
||||
if s is not None:
|
||||
request["s"] = {"t": "i", "i": s}
|
||||
if p is not None:
|
||||
request["p"] = {"t": "i", "i": p}
|
||||
if o is not None:
|
||||
if isinstance(o, str):
|
||||
if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"):
|
||||
request["o"] = {"t": "i", "i": o}
|
||||
else:
|
||||
request["o"] = {"t": "l", "v": o}
|
||||
elif isinstance(o, dict):
|
||||
request["o"] = o
|
||||
if g is not None:
|
||||
request["g"] = g
|
||||
|
||||
triples = []
|
||||
try:
|
||||
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
|
||||
if isinstance(response, dict):
|
||||
triple_list = response.get("response", response.get("triples", []))
|
||||
else:
|
||||
triple_list = response
|
||||
|
||||
if not isinstance(triple_list, list):
|
||||
triple_list = [triple_list] if triple_list else []
|
||||
|
||||
for t in triple_list:
|
||||
s_val = extract_value(t.get("s", {}))
|
||||
p_val = extract_value(t.get("p", {}))
|
||||
o_val = extract_value(t.get("o", {}))
|
||||
triples.append((s_val, p_val, o_val))
|
||||
except Exception as e:
|
||||
print(f"Error querying triples: {e}", file=sys.stderr)
|
||||
|
||||
return triples
|
||||
# Provenance predicates for edge tracing
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
TG_REIFIES = TG + "reifies"
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
||||
|
||||
|
||||
def extract_value(term):
|
||||
"""Extract value from a term dict."""
|
||||
if not term:
|
||||
return ""
|
||||
def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_client):
|
||||
"""
|
||||
Trace an edge back to its source document via reification.
|
||||
|
||||
t = term.get("t") or term.get("type")
|
||||
Args:
|
||||
flow: SocketFlowInstance
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
edge: Dict with s, p, o keys
|
||||
label_cache: Dict for caching labels
|
||||
explain_client: ExplainabilityClient for label resolution
|
||||
|
||||
if t == "i":
|
||||
return term.get("i") or term.get("iri", "")
|
||||
elif t == "l":
|
||||
return term.get("v") or term.get("value", "")
|
||||
elif t == "t":
|
||||
# Quoted triple
|
||||
tr = term.get("tr") or term.get("triple", {})
|
||||
return {
|
||||
"s": extract_value(tr.get("s", {})),
|
||||
"p": extract_value(tr.get("p", {})),
|
||||
"o": extract_value(tr.get("o", {})),
|
||||
}
|
||||
Returns:
|
||||
List of provenance chains, each chain is list of {uri, label}
|
||||
"""
|
||||
edge_s = edge.get("s", "")
|
||||
edge_p = edge.get("p", "")
|
||||
edge_o = edge.get("o", "")
|
||||
|
||||
# Fallback for raw values
|
||||
if "i" in term:
|
||||
return term["i"]
|
||||
if "v" in term:
|
||||
return term["v"]
|
||||
# Build quoted triple for lookup
|
||||
def build_term(val):
|
||||
if isinstance(val, str) and (val.startswith("http") or val.startswith("urn:")):
|
||||
return {"t": "i", "i": val}
|
||||
return {"t": "l", "v": str(val)}
|
||||
|
||||
return str(term)
|
||||
|
||||
|
||||
def get_node_properties(socket, flow_id, user, collection, node_uri, graph=RETRIEVAL_GRAPH):
|
||||
"""Get all properties of a node as a dict."""
|
||||
triples = query_triples(socket, flow_id, user, collection, s=node_uri, g=graph)
|
||||
props = {}
|
||||
for s, p, o in triples:
|
||||
if p not in props:
|
||||
props[p] = []
|
||||
props[p].append(o)
|
||||
return props
|
||||
|
||||
|
||||
def find_by_predicate_object(socket, flow_id, user, collection, predicate, obj, graph=RETRIEVAL_GRAPH):
|
||||
"""Find subjects where predicate = obj."""
|
||||
triples = query_triples(socket, flow_id, user, collection, p=predicate, o=obj, g=graph)
|
||||
return [s for s, p, o in triples]
|
||||
|
||||
|
||||
def get_label(socket, flow_id, user, collection, uri, label_cache):
|
||||
"""Get label for a URI, with caching."""
|
||||
if not isinstance(uri, str) or not (uri.startswith("http://") or uri.startswith("https://") or uri.startswith("urn:")):
|
||||
return uri
|
||||
|
||||
if uri in label_cache:
|
||||
return label_cache[uri]
|
||||
|
||||
triples = query_triples(socket, flow_id, user, collection, s=uri, p=RDFS_LABEL)
|
||||
for s, p, o in triples:
|
||||
label_cache[uri] = o
|
||||
return o
|
||||
|
||||
label_cache[uri] = uri
|
||||
return uri
|
||||
|
||||
|
||||
def get_document_content(api, user, doc_id, max_content):
|
||||
"""Fetch document content from librarian API."""
|
||||
try:
|
||||
library = api.library()
|
||||
content = library.get_document_content(user=user, id=doc_id)
|
||||
|
||||
# Try to decode as text
|
||||
try:
|
||||
text = content.decode('utf-8')
|
||||
if len(text) > max_content:
|
||||
return text[:max_content] + "... [truncated]"
|
||||
return text
|
||||
except UnicodeDecodeError:
|
||||
return f"[Binary: {len(content)} bytes]"
|
||||
except Exception as e:
|
||||
return f"[Error fetching content: {e}]"
|
||||
|
||||
|
||||
def trace_edge_provenance(socket, flow_id, user, collection, edge_s, edge_p, edge_o, label_cache):
|
||||
"""Trace an edge back to its source document via reification."""
|
||||
# Build the quoted triple for lookup
|
||||
quoted_triple = {
|
||||
"t": "t",
|
||||
"tr": {
|
||||
"s": {"t": "i", "i": edge_s} if isinstance(edge_s, str) and (edge_s.startswith("http") or edge_s.startswith("urn:")) else {"t": "l", "v": edge_s},
|
||||
"p": {"t": "i", "i": edge_p},
|
||||
"o": {"t": "i", "i": edge_o} if isinstance(edge_o, str) and (edge_o.startswith("http") or edge_o.startswith("urn:")) else {"t": "l", "v": edge_o},
|
||||
"s": build_term(edge_s),
|
||||
"p": build_term(edge_p),
|
||||
"o": build_term(edge_o),
|
||||
}
|
||||
}
|
||||
|
||||
# Query: ?stmt tg:reifies <<edge>>
|
||||
request = {
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": 10,
|
||||
"streaming": False,
|
||||
"p": {"t": "i", "i": TG_REIFIES},
|
||||
"o": quoted_triple,
|
||||
"g": SOURCE_GRAPH,
|
||||
}
|
||||
|
||||
stmt_uris = []
|
||||
try:
|
||||
for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True):
|
||||
if isinstance(response, dict):
|
||||
triple_list = response.get("response", response.get("triples", []))
|
||||
else:
|
||||
triple_list = response
|
||||
|
||||
if not isinstance(triple_list, list):
|
||||
triple_list = [triple_list] if triple_list else []
|
||||
|
||||
for t in triple_list:
|
||||
s_val = extract_value(t.get("s", {}))
|
||||
if s_val:
|
||||
stmt_uris.append(s_val)
|
||||
results = flow.triples_query(
|
||||
p=TG_REIFIES,
|
||||
o=quoted_triple,
|
||||
g=SOURCE_GRAPH,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=10
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
# For each statement, find wasDerivedFrom chain
|
||||
# Extract statement URIs
|
||||
stmt_uris = []
|
||||
for t in results:
|
||||
s_term = t.get("s", {})
|
||||
s_val = s_term.get("i") or s_term.get("v", "")
|
||||
if s_val:
|
||||
stmt_uris.append(s_val)
|
||||
|
||||
# For each statement, trace wasDerivedFrom chain
|
||||
provenance_chains = []
|
||||
for stmt_uri in stmt_uris:
|
||||
chain = trace_provenance_chain(socket, flow_id, user, collection, stmt_uri, label_cache)
|
||||
chain = trace_provenance_chain(flow, user, collection, stmt_uri, label_cache, explain_client)
|
||||
if chain:
|
||||
provenance_chains.append(chain)
|
||||
|
||||
return provenance_chains
|
||||
|
||||
|
||||
def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_cache, max_depth=10):
|
||||
def trace_provenance_chain(flow, user, collection, start_uri, label_cache, explain_client, max_depth=10):
|
||||
"""Trace prov:wasDerivedFrom chain from start_uri to root."""
|
||||
chain = []
|
||||
current = start_uri
|
||||
|
|
@ -248,17 +119,32 @@ def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_c
|
|||
if not current:
|
||||
break
|
||||
|
||||
label = get_label(socket, flow_id, user, collection, current, label_cache)
|
||||
# Get label
|
||||
if current in label_cache:
|
||||
label = label_cache[current]
|
||||
else:
|
||||
label = explain_client.resolve_label(current, user, collection)
|
||||
label_cache[current] = label
|
||||
|
||||
chain.append({"uri": current, "label": label})
|
||||
|
||||
# Get parent
|
||||
triples = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
s=current, p=PROV_WAS_DERIVED_FROM, g=SOURCE_GRAPH
|
||||
)
|
||||
# Get parent via wasDerivedFrom
|
||||
try:
|
||||
results = flow.triples_query(
|
||||
s=current,
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
g=SOURCE_GRAPH,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=1
|
||||
)
|
||||
except Exception:
|
||||
break
|
||||
|
||||
parent = None
|
||||
for s, p, o in triples:
|
||||
parent = o
|
||||
for t in results:
|
||||
o_term = t.get("o", {})
|
||||
parent = o_term.get("i") or o_term.get("v", "")
|
||||
break
|
||||
|
||||
if not parent or parent == current:
|
||||
|
|
@ -276,331 +162,24 @@ def format_provenance_chain(chain):
|
|||
return " -> ".join(labels)
|
||||
|
||||
|
||||
def format_edge(edge, label_cache=None, socket=None, flow_id=None, user=None, collection=None):
|
||||
"""Format a quoted triple edge for display."""
|
||||
if not isinstance(edge, dict):
|
||||
return str(edge)
|
||||
def print_graphrag_text(trace, explain_client, flow, user, collection, show_provenance=False):
|
||||
"""Print GraphRAG trace in text format."""
|
||||
question = trace.get("question")
|
||||
|
||||
s = edge.get("s", "?")
|
||||
p = edge.get("p", "?")
|
||||
o = edge.get("o", "?")
|
||||
|
||||
# Get labels if available
|
||||
if label_cache and socket:
|
||||
s_label = get_label(socket, flow_id, user, collection, s, label_cache)
|
||||
p_label = get_label(socket, flow_id, user, collection, p, label_cache)
|
||||
o_label = get_label(socket, flow_id, user, collection, o, label_cache)
|
||||
else:
|
||||
# Shorten URIs for display
|
||||
s_label = s.split("/")[-1] if "/" in str(s) else s
|
||||
p_label = p.split("/")[-1] if "/" in str(p) else p
|
||||
o_label = o.split("/")[-1] if "/" in str(o) else o
|
||||
|
||||
return f"({s_label}, {p_label}, {o_label})"
|
||||
|
||||
|
||||
def detect_trace_type(socket, flow_id, user, collection, entity_id):
|
||||
"""
|
||||
Detect whether an entity is an agent Question or GraphRAG Question.
|
||||
|
||||
Both have rdf:type = tg:Question, so we distinguish by checking
|
||||
what's derived from it:
|
||||
- Agent: has tg:Analysis or tg:Conclusion derived
|
||||
- GraphRAG: has tg:Exploration derived
|
||||
|
||||
Also checks URI pattern as fallback:
|
||||
- urn:trustgraph:agent: -> agent
|
||||
- urn:trustgraph:question: -> graphrag
|
||||
|
||||
Returns:
|
||||
"agent" or "graphrag"
|
||||
"""
|
||||
# Check URI pattern first (fast path)
|
||||
if entity_id.startswith("urn:trustgraph:agent:"):
|
||||
return "agent"
|
||||
if entity_id.startswith("urn:trustgraph:question:"):
|
||||
return "graphrag"
|
||||
|
||||
# Check what's derived from this entity
|
||||
derived = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_DERIVED_FROM, entity_id
|
||||
)
|
||||
|
||||
# Also check wasGeneratedBy (GraphRAG exploration uses this)
|
||||
generated = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_GENERATED_BY, entity_id
|
||||
)
|
||||
|
||||
all_children = derived + generated
|
||||
|
||||
for child_id in all_children:
|
||||
child_types = query_triples(
|
||||
socket, flow_id, user, collection,
|
||||
s=child_id, p=RDF_TYPE, g=RETRIEVAL_GRAPH
|
||||
)
|
||||
for s, p, o in child_types:
|
||||
if o == TG_ANALYSIS or o == TG_CONCLUSION:
|
||||
return "agent"
|
||||
if o == TG_EXPLORATION:
|
||||
return "graphrag"
|
||||
|
||||
# Default to graphrag
|
||||
return "graphrag"
|
||||
|
||||
|
||||
def build_agent_trace(socket, flow_id, user, collection, session_id, api=None, max_answer=500):
|
||||
"""Build the full explainability trace for an agent session."""
|
||||
trace = {
|
||||
"session_id": session_id,
|
||||
"type": "agent",
|
||||
"question": None,
|
||||
"time": None,
|
||||
"iterations": [],
|
||||
"final_answer": None,
|
||||
}
|
||||
|
||||
# Get session metadata
|
||||
props = get_node_properties(socket, flow_id, user, collection, session_id)
|
||||
trace["question"] = props.get(TG_QUERY, [None])[0]
|
||||
trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0]
|
||||
|
||||
# Find all entities derived from this session (iterations and final)
|
||||
# Start by looking for entities where prov:wasDerivedFrom = session_id
|
||||
current_uri = session_id
|
||||
iteration_num = 1
|
||||
|
||||
while True:
|
||||
# Find entities derived from current
|
||||
derived_ids = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_DERIVED_FROM, current_uri
|
||||
)
|
||||
|
||||
if not derived_ids:
|
||||
break
|
||||
|
||||
derived_id = derived_ids[0]
|
||||
derived_props = get_node_properties(socket, flow_id, user, collection, derived_id)
|
||||
|
||||
# Check type
|
||||
types = derived_props.get(RDF_TYPE, [])
|
||||
|
||||
if TG_ANALYSIS in types:
|
||||
iteration = {
|
||||
"id": derived_id,
|
||||
"iteration_num": iteration_num,
|
||||
"thought": derived_props.get(TG_THOUGHT, [None])[0],
|
||||
"action": derived_props.get(TG_ACTION, [None])[0],
|
||||
"arguments": derived_props.get(TG_ARGUMENTS, [None])[0],
|
||||
"observation": derived_props.get(TG_OBSERVATION, [None])[0],
|
||||
}
|
||||
trace["iterations"].append(iteration)
|
||||
current_uri = derived_id
|
||||
iteration_num += 1
|
||||
|
||||
elif TG_CONCLUSION in types:
|
||||
answer = derived_props.get(TG_ANSWER, [None])[0]
|
||||
if answer and len(answer) > max_answer:
|
||||
answer = answer[:max_answer] + "... [truncated]"
|
||||
trace["final_answer"] = {
|
||||
"id": derived_id,
|
||||
"answer": answer,
|
||||
}
|
||||
break
|
||||
|
||||
else:
|
||||
# Unknown type, stop traversal
|
||||
break
|
||||
|
||||
return trace
|
||||
|
||||
|
||||
def print_agent_text(trace):
|
||||
"""Print agent trace in text format."""
|
||||
print(f"=== Agent Session: {trace['session_id']} ===")
|
||||
print(f"=== GraphRAG Session: {question.uri if question else 'Unknown'} ===")
|
||||
print()
|
||||
|
||||
if trace["question"]:
|
||||
print(f"Question: {trace['question']}")
|
||||
if trace["time"]:
|
||||
print(f"Time: {trace['time']}")
|
||||
print()
|
||||
|
||||
# Analysis steps
|
||||
print("--- Analysis ---")
|
||||
iterations = trace.get("iterations", [])
|
||||
if iterations:
|
||||
for iteration in iterations:
|
||||
print(f"Analysis {iteration['iteration_num']}:")
|
||||
print(f" Thought: {iteration.get('thought', 'N/A')}")
|
||||
print(f" Action: {iteration.get('action', 'N/A')}")
|
||||
|
||||
args = iteration.get('arguments')
|
||||
if args:
|
||||
# Try to pretty-print JSON arguments
|
||||
try:
|
||||
import json
|
||||
args_obj = json.loads(args)
|
||||
args_str = json.dumps(args_obj, indent=4)
|
||||
# Indent each line
|
||||
args_lines = args_str.split('\n')
|
||||
print(f" Arguments:")
|
||||
for line in args_lines:
|
||||
print(f" {line}")
|
||||
except:
|
||||
print(f" Arguments: {args}")
|
||||
else:
|
||||
print(f" Arguments: N/A")
|
||||
|
||||
obs = iteration.get('observation', 'N/A')
|
||||
if obs and len(obs) > 200:
|
||||
obs = obs[:200] + "... [truncated]"
|
||||
print(f" Observation: {obs}")
|
||||
print()
|
||||
else:
|
||||
print("No analysis steps recorded")
|
||||
print()
|
||||
|
||||
# Conclusion
|
||||
print("--- Conclusion ---")
|
||||
final = trace.get("final_answer")
|
||||
if final and final.get("answer"):
|
||||
print("Answer:")
|
||||
for line in final["answer"].split("\n"):
|
||||
print(f" {line}")
|
||||
else:
|
||||
print("No conclusion recorded")
|
||||
|
||||
|
||||
def print_agent_json(trace):
|
||||
"""Print agent trace as JSON."""
|
||||
print(json.dumps(trace, indent=2))
|
||||
|
||||
|
||||
def build_trace(socket, flow_id, user, collection, question_id, api=None, show_provenance=False, max_answer=500):
|
||||
"""Build the full explainability trace for a question."""
|
||||
label_cache = {}
|
||||
|
||||
trace = {
|
||||
"question_id": question_id,
|
||||
"question": None,
|
||||
"time": None,
|
||||
"exploration": None,
|
||||
"focus": None,
|
||||
"synthesis": None,
|
||||
}
|
||||
|
||||
# Get question metadata
|
||||
props = get_node_properties(socket, flow_id, user, collection, question_id)
|
||||
trace["question"] = props.get(TG_QUERY, [None])[0]
|
||||
trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0]
|
||||
|
||||
# Find exploration: ?exploration prov:wasGeneratedBy question_id
|
||||
exploration_ids = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_GENERATED_BY, question_id
|
||||
)
|
||||
|
||||
if exploration_ids:
|
||||
exploration_id = exploration_ids[0]
|
||||
exploration_props = get_node_properties(socket, flow_id, user, collection, exploration_id)
|
||||
trace["exploration"] = {
|
||||
"id": exploration_id,
|
||||
"edge_count": exploration_props.get(TG_EDGE_COUNT, [None])[0],
|
||||
}
|
||||
|
||||
# Find focus: ?focus prov:wasDerivedFrom exploration_id
|
||||
focus_ids = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_DERIVED_FROM, exploration_id
|
||||
)
|
||||
|
||||
if focus_ids:
|
||||
focus_id = focus_ids[0]
|
||||
focus_props = get_node_properties(socket, flow_id, user, collection, focus_id)
|
||||
|
||||
# Get selected edges
|
||||
edge_selection_uris = focus_props.get(TG_SELECTED_EDGE, [])
|
||||
selected_edges = []
|
||||
|
||||
for edge_sel_uri in edge_selection_uris:
|
||||
edge_sel_props = get_node_properties(socket, flow_id, user, collection, edge_sel_uri)
|
||||
edge = edge_sel_props.get(TG_EDGE, [None])[0]
|
||||
reasoning = edge_sel_props.get(TG_REASONING, [None])[0]
|
||||
|
||||
edge_info = {
|
||||
"edge": edge,
|
||||
"reasoning": reasoning,
|
||||
}
|
||||
|
||||
# Trace provenance if requested
|
||||
if show_provenance and isinstance(edge, dict):
|
||||
provenance = trace_edge_provenance(
|
||||
socket, flow_id, user, collection,
|
||||
edge.get("s", ""), edge.get("p", ""), edge.get("o", ""),
|
||||
label_cache
|
||||
)
|
||||
edge_info["provenance"] = provenance
|
||||
|
||||
selected_edges.append(edge_info)
|
||||
|
||||
trace["focus"] = {
|
||||
"id": focus_id,
|
||||
"selected_edges": selected_edges,
|
||||
}
|
||||
|
||||
# Find synthesis: ?synthesis prov:wasDerivedFrom focus_id
|
||||
synthesis_ids = find_by_predicate_object(
|
||||
socket, flow_id, user, collection,
|
||||
PROV_WAS_DERIVED_FROM, focus_id
|
||||
)
|
||||
|
||||
if synthesis_ids:
|
||||
synthesis_id = synthesis_ids[0]
|
||||
synthesis_props = get_node_properties(socket, flow_id, user, collection, synthesis_id)
|
||||
|
||||
# Get content directly or via document reference
|
||||
content = synthesis_props.get(TG_CONTENT, [None])[0]
|
||||
doc_id = synthesis_props.get(TG_DOCUMENT, [None])[0]
|
||||
|
||||
if not content and doc_id and api:
|
||||
content = get_document_content(api, user, doc_id, max_answer)
|
||||
elif content and len(content) > max_answer:
|
||||
content = content[:max_answer] + "... [truncated]"
|
||||
|
||||
trace["synthesis"] = {
|
||||
"id": synthesis_id,
|
||||
"document_id": doc_id,
|
||||
"answer": content,
|
||||
}
|
||||
|
||||
# Store label cache for formatting
|
||||
trace["_label_cache"] = label_cache
|
||||
|
||||
return trace
|
||||
|
||||
|
||||
def print_text(trace, show_provenance=False):
|
||||
"""Print trace in text format."""
|
||||
label_cache = trace.get("_label_cache", {})
|
||||
|
||||
print(f"=== GraphRAG Session: {trace['question_id']} ===")
|
||||
print()
|
||||
|
||||
if trace["question"]:
|
||||
print(f"Question: {trace['question']}")
|
||||
if trace["time"]:
|
||||
print(f"Time: {trace['time']}")
|
||||
if question:
|
||||
print(f"Question: {question.query}")
|
||||
if question.timestamp:
|
||||
print(f"Time: {question.timestamp}")
|
||||
print()
|
||||
|
||||
# Exploration
|
||||
print("--- Exploration ---")
|
||||
exploration = trace.get("exploration")
|
||||
if exploration:
|
||||
edge_count = exploration.get("edge_count", "?")
|
||||
print(f"Retrieved {edge_count} edges from knowledge graph")
|
||||
print(f"Retrieved {exploration.edge_count} edges from knowledge graph")
|
||||
else:
|
||||
print("No exploration data found")
|
||||
print()
|
||||
|
|
@ -609,24 +188,28 @@ def print_text(trace, show_provenance=False):
|
|||
print("--- Focus (Edge Selection) ---")
|
||||
focus = trace.get("focus")
|
||||
if focus:
|
||||
edges = focus.get("selected_edges", [])
|
||||
edges = focus.edge_selections
|
||||
print(f"Selected {len(edges)} edges:")
|
||||
print()
|
||||
|
||||
for i, edge_info in enumerate(edges, 1):
|
||||
edge = edge_info.get("edge")
|
||||
reasoning = edge_info.get("reasoning")
|
||||
label_cache = {}
|
||||
|
||||
if edge:
|
||||
edge_str = format_edge(edge)
|
||||
print(f" {i}. {edge_str}")
|
||||
for i, edge_sel in enumerate(edges, 1):
|
||||
if edge_sel.edge:
|
||||
s_label, p_label, o_label = explain_client.resolve_edge_labels(
|
||||
edge_sel.edge, user, collection
|
||||
)
|
||||
print(f" {i}. ({s_label}, {p_label}, {o_label})")
|
||||
|
||||
if reasoning:
|
||||
r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
|
||||
if edge_sel.reasoning:
|
||||
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
|
||||
print(f" Reasoning: {r_short}")
|
||||
|
||||
if show_provenance:
|
||||
provenance = edge_info.get("provenance", [])
|
||||
if show_provenance and edge_sel.edge:
|
||||
provenance = trace_edge_provenance(
|
||||
flow, user, collection, edge_sel.edge,
|
||||
label_cache, explain_client
|
||||
)
|
||||
for chain in provenance:
|
||||
chain_str = format_provenance_chain(chain)
|
||||
if chain_str:
|
||||
|
|
@ -641,11 +224,9 @@ def print_text(trace, show_provenance=False):
|
|||
print("--- Synthesis ---")
|
||||
synthesis = trace.get("synthesis")
|
||||
if synthesis:
|
||||
answer = synthesis.get("answer")
|
||||
if answer:
|
||||
if synthesis.content:
|
||||
print("Answer:")
|
||||
# Indent the answer
|
||||
for line in answer.split("\n"):
|
||||
for line in synthesis.content.split("\n"):
|
||||
print(f" {line}")
|
||||
else:
|
||||
print("No answer content found")
|
||||
|
|
@ -653,11 +234,173 @@ def print_text(trace, show_provenance=False):
|
|||
print("No synthesis data found")
|
||||
|
||||
|
||||
def print_json(trace):
|
||||
"""Print trace as JSON."""
|
||||
# Remove internal cache before printing
|
||||
output = {k: v for k, v in trace.items() if not k.startswith("_")}
|
||||
print(json.dumps(output, indent=2))
|
||||
def print_docrag_text(trace):
|
||||
"""Print DocRAG trace in text format."""
|
||||
question = trace.get("question")
|
||||
|
||||
print(f"=== DocRAG Session: {question.uri if question else 'Unknown'} ===")
|
||||
print()
|
||||
|
||||
if question:
|
||||
print(f"Question: {question.query}")
|
||||
if question.timestamp:
|
||||
print(f"Time: {question.timestamp}")
|
||||
print()
|
||||
|
||||
# Exploration
|
||||
print("--- Exploration ---")
|
||||
exploration = trace.get("exploration")
|
||||
if exploration:
|
||||
print(f"Retrieved {exploration.chunk_count} chunks from document store")
|
||||
else:
|
||||
print("No exploration data found")
|
||||
print()
|
||||
|
||||
# Synthesis (no Focus step for DocRAG)
|
||||
print("--- Synthesis ---")
|
||||
synthesis = trace.get("synthesis")
|
||||
if synthesis:
|
||||
if synthesis.content:
|
||||
print("Answer:")
|
||||
for line in synthesis.content.split("\n"):
|
||||
print(f" {line}")
|
||||
else:
|
||||
print("No answer content found")
|
||||
else:
|
||||
print("No synthesis data found")
|
||||
|
||||
|
||||
def print_agent_text(trace):
|
||||
"""Print Agent trace in text format."""
|
||||
question = trace.get("question")
|
||||
|
||||
print(f"=== Agent Session: {question.uri if question else 'Unknown'} ===")
|
||||
print()
|
||||
|
||||
if question:
|
||||
print(f"Question: {question.query}")
|
||||
if question.timestamp:
|
||||
print(f"Time: {question.timestamp}")
|
||||
print()
|
||||
|
||||
# Analysis steps
|
||||
print("--- Analysis ---")
|
||||
iterations = trace.get("iterations", [])
|
||||
if iterations:
|
||||
for i, analysis in enumerate(iterations, 1):
|
||||
print(f"Analysis {i}:")
|
||||
print(f" Thought: {analysis.thought or 'N/A'}")
|
||||
print(f" Action: {analysis.action or 'N/A'}")
|
||||
|
||||
if analysis.arguments:
|
||||
# Try to pretty-print JSON arguments
|
||||
try:
|
||||
args_obj = json.loads(analysis.arguments)
|
||||
args_str = json.dumps(args_obj, indent=4)
|
||||
print(f" Arguments:")
|
||||
for line in args_str.split('\n'):
|
||||
print(f" {line}")
|
||||
except Exception:
|
||||
print(f" Arguments: {analysis.arguments}")
|
||||
else:
|
||||
print(f" Arguments: N/A")
|
||||
|
||||
obs = analysis.observation or 'N/A'
|
||||
if obs and len(obs) > 200:
|
||||
obs = obs[:200] + "... [truncated]"
|
||||
print(f" Observation: {obs}")
|
||||
print()
|
||||
else:
|
||||
print("No analysis steps recorded")
|
||||
print()
|
||||
|
||||
# Conclusion
|
||||
print("--- Conclusion ---")
|
||||
conclusion = trace.get("conclusion")
|
||||
if conclusion and conclusion.answer:
|
||||
print("Answer:")
|
||||
for line in conclusion.answer.split("\n"):
|
||||
print(f" {line}")
|
||||
else:
|
||||
print("No conclusion recorded")
|
||||
|
||||
|
||||
def trace_to_dict(trace, trace_type):
|
||||
"""Convert trace entities to JSON-serializable dict."""
|
||||
if trace_type == "agent":
|
||||
question = trace.get("question")
|
||||
return {
|
||||
"type": "agent",
|
||||
"session_id": question.uri if question else None,
|
||||
"question": question.query if question else None,
|
||||
"time": question.timestamp if question else None,
|
||||
"iterations": [
|
||||
{
|
||||
"id": a.uri,
|
||||
"thought": a.thought,
|
||||
"action": a.action,
|
||||
"arguments": a.arguments,
|
||||
"observation": a.observation,
|
||||
}
|
||||
for a in trace.get("iterations", [])
|
||||
],
|
||||
"conclusion": {
|
||||
"id": trace["conclusion"].uri,
|
||||
"answer": trace["conclusion"].answer,
|
||||
} if trace.get("conclusion") else None,
|
||||
}
|
||||
elif trace_type == "docrag":
|
||||
question = trace.get("question")
|
||||
exploration = trace.get("exploration")
|
||||
synthesis = trace.get("synthesis")
|
||||
|
||||
return {
|
||||
"type": "docrag",
|
||||
"question_id": question.uri if question else None,
|
||||
"question": question.query if question else None,
|
||||
"time": question.timestamp if question else None,
|
||||
"exploration": {
|
||||
"id": exploration.uri,
|
||||
"chunk_count": exploration.chunk_count,
|
||||
} if exploration else None,
|
||||
"synthesis": {
|
||||
"id": synthesis.uri,
|
||||
"document_uri": synthesis.document_uri,
|
||||
"answer": synthesis.content,
|
||||
} if synthesis else None,
|
||||
}
|
||||
else:
|
||||
# graphrag
|
||||
question = trace.get("question")
|
||||
exploration = trace.get("exploration")
|
||||
focus = trace.get("focus")
|
||||
synthesis = trace.get("synthesis")
|
||||
|
||||
return {
|
||||
"type": "graphrag",
|
||||
"question_id": question.uri if question else None,
|
||||
"question": question.query if question else None,
|
||||
"time": question.timestamp if question else None,
|
||||
"exploration": {
|
||||
"id": exploration.uri,
|
||||
"edge_count": exploration.edge_count,
|
||||
} if exploration else None,
|
||||
"focus": {
|
||||
"id": focus.uri,
|
||||
"selected_edges": [
|
||||
{
|
||||
"edge": edge_sel.edge,
|
||||
"reasoning": edge_sel.reasoning,
|
||||
}
|
||||
for edge_sel in focus.edge_selections
|
||||
],
|
||||
} if focus else None,
|
||||
"synthesis": {
|
||||
"id": synthesis.uri,
|
||||
"document_uri": synthesis.document_uri,
|
||||
"answer": synthesis.content,
|
||||
} if synthesis else None,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -727,50 +470,69 @@ def main():
|
|||
try:
|
||||
api = Api(args.api_url, token=args.token)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(args.flow_id)
|
||||
explain_client = ExplainabilityClient(flow)
|
||||
|
||||
try:
|
||||
# Detect trace type (agent vs graphrag)
|
||||
trace_type = detect_trace_type(
|
||||
socket=socket,
|
||||
flow_id=args.flow_id,
|
||||
# Detect trace type
|
||||
trace_type = explain_client.detect_session_type(
|
||||
args.question_id,
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
entity_id=args.question_id,
|
||||
)
|
||||
|
||||
if trace_type == "agent":
|
||||
# Build and print agent trace
|
||||
trace = build_agent_trace(
|
||||
socket=socket,
|
||||
flow_id=args.flow_id,
|
||||
# Fetch and display agent trace
|
||||
trace = explain_client.fetch_agent_trace(
|
||||
args.question_id,
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
session_id=args.question_id,
|
||||
api=api,
|
||||
max_answer=args.max_answer,
|
||||
max_content=args.max_answer,
|
||||
)
|
||||
|
||||
if args.format == 'json':
|
||||
print_agent_json(trace)
|
||||
print(json.dumps(trace_to_dict(trace, "agent"), indent=2))
|
||||
else:
|
||||
print_agent_text(trace)
|
||||
else:
|
||||
# Build and print GraphRAG trace (existing behavior)
|
||||
trace = build_trace(
|
||||
socket=socket,
|
||||
flow_id=args.flow_id,
|
||||
|
||||
elif trace_type == "docrag":
|
||||
# Fetch and display DocRAG trace
|
||||
trace = explain_client.fetch_docrag_trace(
|
||||
args.question_id,
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
question_id=args.question_id,
|
||||
api=api,
|
||||
show_provenance=args.show_provenance,
|
||||
max_answer=args.max_answer,
|
||||
max_content=args.max_answer,
|
||||
)
|
||||
|
||||
if args.format == 'json':
|
||||
print_json(trace)
|
||||
print(json.dumps(trace_to_dict(trace, "docrag"), indent=2))
|
||||
else:
|
||||
print_text(trace, show_provenance=args.show_provenance)
|
||||
print_docrag_text(trace)
|
||||
|
||||
else:
|
||||
# Fetch and display GraphRAG trace
|
||||
trace = explain_client.fetch_graphrag_trace(
|
||||
args.question_id,
|
||||
graph=RETRIEVAL_GRAPH,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
api=api,
|
||||
max_content=args.max_answer,
|
||||
)
|
||||
|
||||
if args.format == 'json':
|
||||
print(json.dumps(trace_to_dict(trace, "graphrag"), indent=2))
|
||||
else:
|
||||
print_graphrag_text(
|
||||
trace, explain_client, flow,
|
||||
args.user, args.collection,
|
||||
show_provenance=args.show_provenance
|
||||
)
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
Simple agent infrastructure broadly implements the ReAct flow.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
|
|
@ -17,9 +19,13 @@ from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
|||
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
|
||||
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
|
||||
from ... base import ProducerSpec
|
||||
from ... base import Consumer, Producer
|
||||
from ... base import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from ... schema import Triples, Metadata
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
|
||||
# Provenance imports for agent explainability
|
||||
from trustgraph.provenance import (
|
||||
|
|
@ -41,6 +47,8 @@ from . types import Final, Action, Tool, Argument
|
|||
|
||||
default_ident = "agent-manager"
|
||||
default_max_iterations = 10
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
class Processor(AgentService):
|
||||
|
||||
|
|
@ -129,6 +137,115 @@ class Processor(AgentService):
|
|||
)
|
||||
)
|
||||
|
||||
# Librarian client for storing answer content
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend=self.pubsub,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
backend=self.pubsub,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self.on_librarian_response,
|
||||
metrics=librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_librarian_requests = {}
|
||||
|
||||
async def start(self):
|
||||
await super(Processor, self).start()
|
||||
await self.librarian_request_producer.start()
|
||||
await self.librarian_response_consumer.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
|
||||
"""Handle responses from the librarian service."""
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if request_id in self.pending_librarian_requests:
|
||||
future = self.pending_librarian_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
|
||||
"""
|
||||
Save answer content to the librarian.
|
||||
|
||||
Args:
|
||||
doc_id: ID for the answer document
|
||||
user: User ID
|
||||
content: Answer text content
|
||||
title: Optional title
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
The document ID on success
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind="text/plain",
|
||||
title=title or "Agent Answer",
|
||||
document_type="answer",
|
||||
)
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="add-document",
|
||||
document_id=doc_id,
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_librarian_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return doc_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_librarian_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout saving answer document {doc_id}")
|
||||
|
||||
async def on_tools_config(self, config, version):
|
||||
|
||||
logger.info(f"Loading configuration version {version}")
|
||||
|
|
@ -347,6 +464,15 @@ class Processor(AgentService):
|
|||
))
|
||||
logger.debug(f"Emitted session triples for {session_uri}")
|
||||
|
||||
# Send explain event for session
|
||||
if streaming:
|
||||
await respond(AgentResponse(
|
||||
chunk_type="explain",
|
||||
content="",
|
||||
explain_id=session_uri,
|
||||
explain_graph=GRAPH_RETRIEVAL,
|
||||
))
|
||||
|
||||
logger.info(f"Question: {request.question}")
|
||||
|
||||
if len(history) >= self.max_iterations:
|
||||
|
|
@ -504,8 +630,28 @@ class Processor(AgentService):
|
|||
else:
|
||||
parent_uri = session_uri
|
||||
|
||||
# Save answer to librarian
|
||||
answer_doc_id = None
|
||||
if f:
|
||||
answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
|
||||
try:
|
||||
await self.save_answer_content(
|
||||
doc_id=answer_doc_id,
|
||||
user=request.user,
|
||||
content=f,
|
||||
title=f"Agent Answer: {request.question[:50]}...",
|
||||
)
|
||||
logger.debug(f"Saved answer to librarian: {answer_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
answer_doc_id = None # Fall back to inline content
|
||||
|
||||
final_triples = set_graph(
|
||||
agent_final_triples(final_uri, parent_uri, f),
|
||||
agent_final_triples(
|
||||
final_uri, parent_uri,
|
||||
answer="" if answer_doc_id else f,
|
||||
document_id=answer_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await flow("explainability").send(Triples(
|
||||
|
|
@ -518,6 +664,15 @@ class Processor(AgentService):
|
|||
))
|
||||
logger.debug(f"Emitted final triples for {final_uri}")
|
||||
|
||||
# Send explain event for conclusion
|
||||
if streaming:
|
||||
await respond(AgentResponse(
|
||||
chunk_type="explain",
|
||||
content="",
|
||||
explain_id=final_uri,
|
||||
explain_graph=GRAPH_RETRIEVAL,
|
||||
))
|
||||
|
||||
if streaming:
|
||||
# Streaming format - send end-of-dialog marker
|
||||
# Answer chunks were already sent via answer() callback during parsing
|
||||
|
|
@ -558,14 +713,48 @@ class Processor(AgentService):
|
|||
else:
|
||||
parent_uri = session_uri
|
||||
|
||||
# Save thought to librarian
|
||||
thought_doc_id = None
|
||||
if act.thought:
|
||||
thought_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
|
||||
try:
|
||||
await self.save_answer_content(
|
||||
doc_id=thought_doc_id,
|
||||
user=request.user,
|
||||
content=act.thought,
|
||||
title=f"Agent Thought: {act.name}",
|
||||
)
|
||||
logger.debug(f"Saved thought to librarian: {thought_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save thought to librarian: {e}")
|
||||
thought_doc_id = None
|
||||
|
||||
# Save observation to librarian
|
||||
observation_doc_id = None
|
||||
if act.observation:
|
||||
observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
|
||||
try:
|
||||
await self.save_answer_content(
|
||||
doc_id=observation_doc_id,
|
||||
user=request.user,
|
||||
content=act.observation,
|
||||
title=f"Agent Observation: {act.name}",
|
||||
)
|
||||
logger.debug(f"Saved observation to librarian: {observation_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save observation to librarian: {e}")
|
||||
observation_doc_id = None
|
||||
|
||||
iter_triples = set_graph(
|
||||
agent_iteration_triples(
|
||||
iteration_uri,
|
||||
parent_uri,
|
||||
act.thought,
|
||||
act.name,
|
||||
act.arguments,
|
||||
act.observation,
|
||||
thought="" if thought_doc_id else act.thought,
|
||||
action=act.name,
|
||||
arguments=act.arguments,
|
||||
observation="" if observation_doc_id else act.observation,
|
||||
thought_document_id=thought_doc_id,
|
||||
observation_document_id=observation_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
|
|
@ -579,6 +768,15 @@ class Processor(AgentService):
|
|||
))
|
||||
logger.debug(f"Emitted iteration triples for {iteration_uri}")
|
||||
|
||||
# Send explain event for iteration
|
||||
if streaming:
|
||||
await respond(AgentResponse(
|
||||
chunk_type="explain",
|
||||
content="",
|
||||
explain_id=iteration_uri,
|
||||
explain_graph=GRAPH_RETRIEVAL,
|
||||
))
|
||||
|
||||
history.append(act)
|
||||
|
||||
# Handle state transitions if tool execution was successful
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class DocumentRag:
|
|||
async def query(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
doc_limit=20, streaming=False, chunk_callback=None,
|
||||
explain_callback=None,
|
||||
explain_callback=None, save_answer_callback=None,
|
||||
):
|
||||
"""
|
||||
Execute a Document RAG query with optional explainability tracking.
|
||||
|
|
@ -122,6 +122,7 @@ class DocumentRag:
|
|||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for explainability
|
||||
save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
|
||||
|
||||
Returns:
|
||||
str: The synthesized answer text
|
||||
|
|
@ -192,9 +193,28 @@ class DocumentRag:
|
|||
|
||||
# Emit synthesis explainability after answer generated
|
||||
if explain_callback:
|
||||
synthesis_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian if callback provided
|
||||
if save_answer_callback and answer_text:
|
||||
# Generate document ID as URN matching query-time provenance format
|
||||
synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
|
||||
try:
|
||||
await save_answer_callback(synthesis_doc_id, answer_text)
|
||||
if self.verbose:
|
||||
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
synthesis_doc_id = None # Fall back to inline content
|
||||
|
||||
# Generate triples with document reference or inline content
|
||||
syn_triples = set_graph(
|
||||
docrag_synthesis_triples(syn_uri, exp_uri, answer_text),
|
||||
docrag_synthesis_triples(
|
||||
syn_uri, exp_uri,
|
||||
answer_text="" if synthesis_doc_id else answer_text,
|
||||
document_id=synthesis_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(syn_triples, syn_uri)
|
||||
|
|
|
|||
|
|
@ -8,8 +8,10 @@ import asyncio
|
|||
import base64
|
||||
import logging
|
||||
|
||||
import uuid
|
||||
|
||||
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
|
||||
from ... schema import LibrarianRequest, LibrarianResponse
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from ... schema import Triples, Metadata
|
||||
from ... provenance import GRAPH_RETRIEVAL
|
||||
|
|
@ -179,6 +181,62 @@ class Processor(FlowProcessor):
|
|||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout fetching chunk {chunk_id}")
|
||||
|
||||
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
|
||||
"""
|
||||
Save answer content to the librarian.
|
||||
|
||||
Args:
|
||||
doc_id: ID for the answer document
|
||||
user: User ID
|
||||
content: Answer text content
|
||||
title: Optional title
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
The document ID on success
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind="text/plain",
|
||||
title=title or "DocumentRAG Answer",
|
||||
document_type="answer",
|
||||
)
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="add-document",
|
||||
document_id=doc_id,
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return doc_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout saving answer document {doc_id}")
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
try:
|
||||
|
|
@ -222,10 +280,20 @@ class Processor(FlowProcessor):
|
|||
response=None,
|
||||
explain_id=explain_id,
|
||||
explain_graph=GRAPH_RETRIEVAL,
|
||||
message_type="explain",
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
# Callback to save answer content to librarian
|
||||
async def save_answer(doc_id, answer_text):
|
||||
await self.save_answer_content(
|
||||
doc_id=doc_id,
|
||||
user=v.user,
|
||||
content=answer_text,
|
||||
title=f"DocumentRAG Answer: {v.query[:50]}...",
|
||||
)
|
||||
|
||||
# Check if streaming is requested
|
||||
if v.streaming:
|
||||
# Define async callback for streaming chunks
|
||||
|
|
@ -235,6 +303,7 @@ class Processor(FlowProcessor):
|
|||
DocumentRagResponse(
|
||||
response=chunk,
|
||||
end_of_stream=end_of_stream,
|
||||
message_type="chunk",
|
||||
error=None
|
||||
),
|
||||
properties={"id": id}
|
||||
|
|
@ -250,6 +319,17 @@ class Processor(FlowProcessor):
|
|||
streaming=True,
|
||||
chunk_callback=send_chunk,
|
||||
explain_callback=send_explainability,
|
||||
save_answer_callback=save_answer,
|
||||
)
|
||||
|
||||
# Send end_of_session to signal entire session is complete
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response=None,
|
||||
end_of_session=True,
|
||||
message_type="end",
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
else:
|
||||
# Non-streaming path (existing behavior)
|
||||
|
|
@ -259,6 +339,7 @@ class Processor(FlowProcessor):
|
|||
collection=v.collection,
|
||||
doc_limit=doc_limit,
|
||||
explain_callback=send_explainability,
|
||||
save_answer_callback=save_answer,
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue