diff --git a/tests/contract/test_translator_completion_flags.py b/tests/contract/test_translator_completion_flags.py index a92705a0..dc7d5748 100644 --- a/tests/contract/test_translator_completion_flags.py +++ b/tests/contract/test_translator_completion_flags.py @@ -110,16 +110,17 @@ class TestRAGTranslatorCompletionFlags: assert response_dict["end_of_stream"] is True assert response_dict["end_of_session"] is False - def test_document_rag_translator_is_final_with_end_of_stream_true(self): + def test_document_rag_translator_is_final_with_end_of_session_true(self): """ Test that DocumentRagResponseTranslator returns is_final=True - when end_of_stream=True. + when end_of_session=True. """ # Arrange translator = TranslatorRegistry.get_response_translator("document-rag") response = DocumentRagResponse( response="A document about cats.", end_of_stream=True, + end_of_session=True, error=None ) @@ -127,9 +128,31 @@ class TestRAGTranslatorCompletionFlags: response_dict, is_final = translator.from_response_with_completion(response) # Assert - assert is_final is True, "is_final must be True when end_of_stream=True" + assert is_final is True, "is_final must be True when end_of_session=True" assert response_dict["response"] == "A document about cats." + assert response_dict["end_of_session"] is True + + def test_document_rag_translator_end_of_stream_not_final(self): + """ + Test that end_of_stream=True alone does NOT make is_final=True. + The session continues with provenance messages after LLM stream completes. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("document-rag") + response = DocumentRagResponse( + response="Final chunk", + end_of_stream=True, + end_of_session=False, # Session continues with provenance + error=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is False, "end_of_stream=True should NOT make is_final=True" assert response_dict["end_of_stream"] is True + assert response_dict["end_of_session"] is False def test_document_rag_translator_is_final_with_end_of_stream_false(self): """ diff --git a/tests/integration/test_agent_structured_query_integration.py b/tests/integration/test_agent_structured_query_integration.py index f4f59444..0fedd2b5 100644 --- a/tests/integration/test_agent_structured_query_integration.py +++ b/tests/integration/test_agent_structured_query_integration.py @@ -30,10 +30,13 @@ class TestAgentStructuredQueryIntegration: pulsar_client=AsyncMock(), max_iterations=3 ) - + # Mock the client method for structured query proc.client = MagicMock() - + + # Mock librarian to avoid hanging on save operations + proc.save_answer_content = AsyncMock(return_value=None) + return proc @pytest.fixture diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py index 0fd2060d..2ef64e96 100644 --- a/tests/unit/test_agent/test_agent_service_non_streaming.py +++ b/tests/unit/test_agent/test_agent_service_non_streaming.py @@ -28,6 +28,9 @@ class TestAgentServiceNonStreaming: max_iterations=10 ) + # Mock librarian to avoid hanging on save operations + processor.save_answer_content = AsyncMock(return_value=None) + # Track all responses sent sent_responses = [] @@ -106,6 +109,9 @@ class TestAgentServiceNonStreaming: max_iterations=10 ) + # Mock librarian to avoid hanging on save operations + processor.save_answer_content = AsyncMock(return_value=None) + # Track all responses sent sent_responses = [] @@ -173,6 +179,9 @@ class TestAgentServiceNonStreaming: max_iterations=10 ) + # Mock librarian to avoid hanging on save operations + processor.save_answer_content = AsyncMock(return_value=None) + # Track all responses sent sent_responses = [] diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index d6b5031a..05e1bb60 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -68,6 +68,7 @@ class TestDocumentRagService: collection="test_coll_1", # Must be from message, not hardcoded default doc_limit=5, explain_callback=ANY, # Explainability callback is always passed + save_answer_callback=ANY, # Librarian save callback is always passed ) # Verify response was sent diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index daa2cc5c..e71e192c 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -59,7 +59,7 @@ from .flow import Flow, FlowInstance from .async_flow import AsyncFlow, AsyncFlowInstance # WebSocket clients -from .socket_client import SocketClient, SocketFlowInstance +from .socket_client import SocketClient, SocketFlowInstance, build_term from .async_socket_client import AsyncSocketClient, AsyncSocketFlowInstance # Bulk operation clients @@ -70,6 +70,21 @@ from .async_bulk_client import AsyncBulkClient from .metrics import Metrics from .async_metrics import AsyncMetrics +# Explainability +from .explainability import ( + ExplainabilityClient, + ExplainEntity, + Question, + Exploration, + Focus, + Synthesis, + Analysis, + Conclusion, + EdgeSelection, + wire_triples_to_tuples, + extract_term_value, +) + # Types from .types import ( Triple, @@ -85,6 +100,7 @@ from .types import ( AgentObservation, AgentAnswer, RAGChunk, + ProvenanceEvent, ) # Exceptions @@ -124,6 +140,7 @@ __all__ = [ "SocketFlowInstance", "AsyncSocketClient", "AsyncSocketFlowInstance", + "build_term", # Bulk operation clients "BulkClient", @@ -133,6 +150,19 @@ __all__ = [ "Metrics", "AsyncMetrics", + # Explainability + "ExplainabilityClient", + "ExplainEntity", + "Question", + "Exploration", + "Focus", + "Synthesis", + "Analysis", + "Conclusion", + "EdgeSelection", + "wire_triples_to_tuples", + "extract_term_value", + # Types "Triple", "Uri", @@ -147,6 +177,7 @@ __all__ = [ "AgentObservation", "AgentAnswer", "RAGChunk", + "ProvenanceEvent", # Exceptions "ProtocolException", diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py new file mode 100644 index 00000000..b7ebca0e --- /dev/null +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -0,0 +1,1132 @@ +""" +Explainability support for TrustGraph API. + +Provides classes for explainability entities (Question, Exploration, Focus, +Synthesis, Analysis, Conclusion) and utilities for fetching them with +eventual consistency handling. +""" + +import asyncio +import time +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any, Tuple, Union + +# Provenance predicates +TG = "https://trustgraph.ai/ns/" +TG_QUERY = TG + "query" +TG_EDGE_COUNT = TG + "edgeCount" +TG_SELECTED_EDGE = TG + "selectedEdge" +TG_EDGE = TG + "edge" +TG_REASONING = TG + "reasoning" +TG_CONTENT = TG + "content" +TG_DOCUMENT = TG + "document" +TG_CHUNK_COUNT = TG + "chunkCount" +TG_SELECTED_CHUNK = TG + "selectedChunk" +TG_THOUGHT = TG + "thought" +TG_ACTION = TG + "action" +TG_ARGUMENTS = TG + "arguments" +TG_OBSERVATION = TG + "observation" +TG_ANSWER = TG + "answer" +TG_THOUGHT_DOCUMENT = TG + "thoughtDocument" +TG_OBSERVATION_DOCUMENT = TG + "observationDocument" + +# Entity types +TG_QUESTION = TG + "Question" +TG_EXPLORATION = TG + "Exploration" +TG_FOCUS = TG + "Focus" +TG_SYNTHESIS = TG + "Synthesis" +TG_ANALYSIS = TG + "Analysis" +TG_CONCLUSION = TG + "Conclusion" +TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion" +TG_DOC_RAG_QUESTION = TG + "DocRagQuestion" +TG_AGENT_QUESTION = TG + "AgentQuestion" + +# PROV-O predicates +PROV = "http://www.w3.org/ns/prov#" +PROV_STARTED_AT_TIME = PROV + "startedAtTime" +PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" +PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy" + +RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" +RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" + + +@dataclass +class EdgeSelection: + """A selected edge with reasoning from GraphRAG Focus step.""" + uri: str + edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...} + reasoning: str = "" + + +@dataclass +class ExplainEntity: + """Base class for explainability entities.""" + uri: str + entity_type: str = "" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "ExplainEntity": + """Parse triples into the appropriate entity type.""" + # Determine entity type from rdf:type triples + types = [o for s, p, o in triples if p == RDF_TYPE] + + if TG_GRAPH_RAG_QUESTION in types or TG_DOC_RAG_QUESTION in types or TG_AGENT_QUESTION in types: + return Question.from_triples(uri, triples, types) + elif TG_EXPLORATION in types: + return Exploration.from_triples(uri, triples) + elif TG_FOCUS in types: + return Focus.from_triples(uri, triples) + elif TG_SYNTHESIS in types: + return Synthesis.from_triples(uri, triples) + elif TG_ANALYSIS in types: + return Analysis.from_triples(uri, triples) + elif TG_CONCLUSION in types: + return Conclusion.from_triples(uri, triples) + else: + # Generic entity + return ExplainEntity(uri=uri, entity_type="unknown") + + +@dataclass +class Question(ExplainEntity): + """Question entity - the user's query that started the session.""" + query: str = "" + timestamp: str = "" + question_type: str = "" # "graph-rag", "document-rag", "agent" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]], + types: List[str]) -> "Question": + query = "" + timestamp = "" + question_type = "unknown" + + for s, p, o in triples: + if p == TG_QUERY: + query = o + elif p == PROV_STARTED_AT_TIME: + timestamp = o + + if TG_GRAPH_RAG_QUESTION in types: + question_type = "graph-rag" + elif TG_DOC_RAG_QUESTION in types: + question_type = "document-rag" + elif TG_AGENT_QUESTION in types: + question_type = "agent" + + return cls( + uri=uri, + entity_type="question", + query=query, + timestamp=timestamp, + question_type=question_type + ) + + +@dataclass +class Exploration(ExplainEntity): + """Exploration entity - edges/chunks retrieved from the knowledge store.""" + edge_count: int = 0 + chunk_count: int = 0 + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Exploration": + edge_count = 0 + chunk_count = 0 + + for s, p, o in triples: + if p == TG_EDGE_COUNT: + try: + edge_count = int(o) + except (ValueError, TypeError): + pass + elif p == TG_CHUNK_COUNT: + try: + chunk_count = int(o) + except (ValueError, TypeError): + pass + + return cls( + uri=uri, + entity_type="exploration", + edge_count=edge_count, + chunk_count=chunk_count + ) + + +@dataclass +class Focus(ExplainEntity): + """Focus entity - selected edges with LLM reasoning (GraphRAG only).""" + selected_edge_uris: List[str] = field(default_factory=list) + edge_selections: List[EdgeSelection] = field(default_factory=list) + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Focus": + selected_edge_uris = [] + + for s, p, o in triples: + if p == TG_SELECTED_EDGE and isinstance(o, str): + selected_edge_uris.append(o) + + return cls( + uri=uri, + entity_type="focus", + selected_edge_uris=selected_edge_uris, + edge_selections=[] # Populated separately by fetching each edge URI + ) + + +@dataclass +class Synthesis(ExplainEntity): + """Synthesis entity - the final answer.""" + content: str = "" + document_uri: str = "" # Reference to librarian document + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Synthesis": + content = "" + document_uri = "" + + for s, p, o in triples: + if p == TG_CONTENT: + content = o + elif p == TG_DOCUMENT: + document_uri = o + + return cls( + uri=uri, + entity_type="synthesis", + content=content, + document_uri=document_uri + ) + + +@dataclass +class Analysis(ExplainEntity): + """Analysis entity - one think/act/observe cycle (Agent only).""" + thought: str = "" + action: str = "" + arguments: str = "" # JSON string + observation: str = "" + thought_document_uri: str = "" # Reference to thought in librarian + observation_document_uri: str = "" # Reference to observation in librarian + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Analysis": + thought = "" + action = "" + arguments = "" + observation = "" + thought_document_uri = "" + observation_document_uri = "" + + for s, p, o in triples: + if p == TG_THOUGHT: + thought = o + elif p == TG_ACTION: + action = o + elif p == TG_ARGUMENTS: + arguments = o + elif p == TG_OBSERVATION: + observation = o + elif p == TG_THOUGHT_DOCUMENT: + thought_document_uri = o + elif p == TG_OBSERVATION_DOCUMENT: + observation_document_uri = o + + return cls( + uri=uri, + entity_type="analysis", + thought=thought, + action=action, + arguments=arguments, + observation=observation, + thought_document_uri=thought_document_uri, + observation_document_uri=observation_document_uri + ) + + +@dataclass +class Conclusion(ExplainEntity): + """Conclusion entity - final answer (Agent only).""" + answer: str = "" + document_uri: str = "" # Reference to librarian document + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Conclusion": + answer = "" + document_uri = "" + + for s, p, o in triples: + if p == TG_ANSWER: + answer = o + elif p == TG_DOCUMENT: + document_uri = o + + return cls( + uri=uri, + entity_type="conclusion", + answer=answer, + document_uri=document_uri + ) + + +def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSelection: + """Parse triples for an edge selection entity.""" + uri = triples[0][0] if triples else "" + edge = None + reasoning = "" + + for s, p, o in triples: + if p == TG_EDGE and isinstance(o, dict): + edge = o + elif p == TG_REASONING: + reasoning = o + + return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning) + + +def extract_term_value(term: Dict[str, Any]) -> Any: + """Extract value from a wire-format Term dict.""" + t = term.get("t") or term.get("type") + + if t == "i": + return term.get("i") or term.get("iri", "") + elif t == "l": + return term.get("v") or term.get("value", "") + elif t == "t": + # Quoted triple - return as dict + tr = term.get("tr") or term.get("triple", {}) + return { + "s": extract_term_value(tr.get("s", {})), + "p": extract_term_value(tr.get("p", {})), + "o": extract_term_value(tr.get("o", {})), + } + else: + # Unknown format, try common keys + return term.get("i") or term.get("v") or term.get("iri") or term.get("value") or str(term) + + +def wire_triples_to_tuples(wire_triples: List[Dict[str, Any]]) -> List[Tuple[str, str, Any]]: + """Convert wire-format triples to (s, p, o) tuples.""" + result = [] + for t in wire_triples: + s = extract_term_value(t.get("s", {})) + p = extract_term_value(t.get("p", {})) + o = extract_term_value(t.get("o", {})) + result.append((s, p, o)) + return result + + +class ExplainabilityClient: + """ + Client for fetching explainability entities with eventual consistency handling. + + Uses quiescence detection: fetch, wait, fetch again, compare. + If results are the same, data is stable. + """ + + def __init__(self, flow_instance, retry_delay: float = 0.2, max_retries: int = 10): + """ + Initialize explainability client. + + Args: + flow_instance: A SocketFlowInstance for querying triples + retry_delay: Delay between retries in seconds (default: 0.2) + max_retries: Maximum retry attempts (default: 10) + """ + self.flow = flow_instance + self.retry_delay = retry_delay + self.max_retries = max_retries + self._label_cache: Dict[str, str] = {} + + def fetch_entity( + self, + uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None + ) -> Optional[ExplainEntity]: + """ + Fetch an explainability entity by URI with eventual consistency handling. + + Uses quiescence detection: + 1. Fetch triples for URI + 2. If zero results, retry + 3. If non-zero results, wait and fetch again + 4. If same results, data is stable - parse and return + 5. If different results, data still being written - retry + + Args: + uri: The entity URI to fetch + graph: Named graph to query (e.g., "urn:graph:retrieval") + user: User/keyspace identifier + collection: Collection identifier + + Returns: + ExplainEntity subclass or None if not found + """ + prev_triples = None + + for attempt in range(self.max_retries): + # Fetch triples for this URI + wire_triples = self.flow.triples_query( + s=uri, + g=graph, + user=user, + collection=collection, + limit=100 + ) + + if not wire_triples: + # Zero results - definitely retry + time.sleep(self.retry_delay) + continue + + # Convert to comparable format + triples = wire_triples_to_tuples(wire_triples) + triples_set = frozenset((s, p, str(o)) for s, p, o in triples) + + if prev_triples is None: + # First non-empty result - wait and check for stability + prev_triples = triples_set + time.sleep(self.retry_delay) + continue + + if triples_set == prev_triples: + # Same as before - data is stable + return ExplainEntity.from_triples(uri, triples) + else: + # Different - still being written, update and retry + prev_triples = triples_set + time.sleep(self.retry_delay) + continue + + # Max retries reached - return what we have if anything + if prev_triples: + # Re-fetch and parse + wire_triples = self.flow.triples_query( + s=uri, g=graph, user=user, collection=collection, limit=100 + ) + if wire_triples: + triples = wire_triples_to_tuples(wire_triples) + return ExplainEntity.from_triples(uri, triples) + + return None + + def fetch_edge_selection( + self, + uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None + ) -> Optional[EdgeSelection]: + """ + Fetch an edge selection entity (used by Focus). + + Args: + uri: The edge selection URI + graph: Named graph to query + user: User/keyspace identifier + collection: Collection identifier + + Returns: + EdgeSelection or None if not found + """ + wire_triples = self.flow.triples_query( + s=uri, + g=graph, + user=user, + collection=collection, + limit=100 + ) + + if not wire_triples: + return None + + triples = wire_triples_to_tuples(wire_triples) + return parse_edge_selection_triples(triples) + + def fetch_focus_with_edges( + self, + uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None + ) -> Optional[Focus]: + """ + Fetch a Focus entity and all its edge selections. + + Args: + uri: The Focus entity URI + graph: Named graph to query + user: User/keyspace identifier + collection: Collection identifier + + Returns: + Focus with populated edge_selections, or None + """ + entity = self.fetch_entity(uri, graph, user, collection) + + if not isinstance(entity, Focus): + return None + + # Fetch each edge selection + for edge_uri in entity.selected_edge_uris: + edge_sel = self.fetch_edge_selection(edge_uri, graph, user, collection) + if edge_sel: + entity.edge_selections.append(edge_sel) + + return entity + + def resolve_label( + self, + uri: str, + user: Optional[str] = None, + collection: Optional[str] = None + ) -> str: + """ + Resolve rdfs:label for a URI, with caching. + + Args: + uri: The URI to get label for + user: User/keyspace identifier + collection: Collection identifier + + Returns: + The label if found, otherwise the URI itself + """ + if not uri or not uri.startswith(("http://", "https://", "urn:")): + return uri + + if uri in self._label_cache: + return self._label_cache[uri] + + wire_triples = self.flow.triples_query( + s=uri, + p=RDFS_LABEL, + user=user, + collection=collection, + limit=1 + ) + + if wire_triples: + triples = wire_triples_to_tuples(wire_triples) + if triples: + label = triples[0][2] + self._label_cache[uri] = label + return label + + self._label_cache[uri] = uri + return uri + + def resolve_edge_labels( + self, + edge: Dict[str, str], + user: Optional[str] = None, + collection: Optional[str] = None + ) -> Tuple[str, str, str]: + """ + Resolve labels for all components of an edge triple. + + Args: + edge: Dict with "s", "p", "o" keys + user: User/keyspace identifier + collection: Collection identifier + + Returns: + Tuple of (s_label, p_label, o_label) + """ + s_label = self.resolve_label(edge.get("s", ""), user, collection) + p_label = self.resolve_label(edge.get("p", ""), user, collection) + o_label = self.resolve_label(edge.get("o", ""), user, collection) + return (s_label, p_label, o_label) + + def fetch_synthesis_content( + self, + synthesis: Synthesis, + api: Any, + user: Optional[str] = None, + max_content: int = 10000 + ) -> str: + """ + Fetch the content for a Synthesis entity. + + If synthesis has inline content, returns that. + If synthesis has a document_uri, fetches from librarian with retry. + + Args: + synthesis: The Synthesis entity + api: TrustGraph Api instance for librarian access + user: User identifier for librarian + max_content: Maximum content length to return + + Returns: + The synthesis content as a string + """ + # If inline content exists, use it + if synthesis.content: + if len(synthesis.content) > max_content: + return synthesis.content[:max_content] + "... [truncated]" + return synthesis.content + + # Otherwise fetch from librarian + if not synthesis.document_uri: + return "" + + # Extract document ID from URI (e.g., "urn:document:abc123" -> "abc123") + doc_id = synthesis.document_uri + if doc_id.startswith("urn:document:"): + doc_id = doc_id[len("urn:document:"):] + + # Retry fetching from librarian for eventual consistency + for attempt in range(self.max_retries): + try: + library = api.library() + content_bytes = library.get_document_content(user=user, id=doc_id) + + # Decode as text + try: + content = content_bytes.decode('utf-8') + if len(content) > max_content: + return content[:max_content] + "... [truncated]" + return content + except UnicodeDecodeError: + return f"[Binary: {len(content_bytes)} bytes]" + + except Exception as e: + if attempt < self.max_retries - 1: + time.sleep(self.retry_delay) + continue + return f"[Error fetching content: {e}]" + + return "" + + def fetch_conclusion_content( + self, + conclusion: Conclusion, + api: Any, + user: Optional[str] = None, + max_content: int = 10000 + ) -> str: + """ + Fetch the content for a Conclusion entity (Agent final answer). + + If conclusion has inline answer, returns that. + If conclusion has a document_uri, fetches from librarian with retry. + + Args: + conclusion: The Conclusion entity + api: TrustGraph Api instance for librarian access + user: User identifier for librarian + max_content: Maximum content length to return + + Returns: + The conclusion answer as a string + """ + # If inline answer exists, use it + if conclusion.answer: + if len(conclusion.answer) > max_content: + return conclusion.answer[:max_content] + "... [truncated]" + return conclusion.answer + + # Otherwise fetch from librarian + if not conclusion.document_uri: + return "" + + # Use document URI directly (it's already a full URN) + doc_id = conclusion.document_uri + + # Retry fetching from librarian for eventual consistency + for attempt in range(self.max_retries): + try: + library = api.library() + content_bytes = library.get_document_content(user=user, id=doc_id) + + # Decode as text + try: + content = content_bytes.decode('utf-8') + if len(content) > max_content: + return content[:max_content] + "... [truncated]" + return content + except UnicodeDecodeError: + return f"[Binary: {len(content_bytes)} bytes]" + + except Exception as e: + if attempt < self.max_retries - 1: + time.sleep(self.retry_delay) + continue + return f"[Error fetching content: {e}]" + + return "" + + def fetch_analysis_content( + self, + analysis: Analysis, + api: Any, + user: Optional[str] = None, + max_content: int = 10000 + ) -> None: + """ + Fetch thought and observation content for an Analysis entity. + + If analysis has inline content, uses that. + If analysis has document URIs, fetches from librarian with retry. + Modifies the analysis object in place. + + Args: + analysis: The Analysis entity (modified in place) + api: TrustGraph Api instance for librarian access + user: User identifier for librarian + max_content: Maximum content length to return + """ + # Fetch thought if needed + if not analysis.thought and analysis.thought_document_uri: + doc_id = analysis.thought_document_uri + for attempt in range(self.max_retries): + try: + library = api.library() + content_bytes = library.get_document_content(user=user, id=doc_id) + try: + content = content_bytes.decode('utf-8') + if len(content) > max_content: + analysis.thought = content[:max_content] + "... [truncated]" + else: + analysis.thought = content + break + except UnicodeDecodeError: + analysis.thought = f"[Binary: {len(content_bytes)} bytes]" + break + except Exception as e: + if attempt < self.max_retries - 1: + time.sleep(self.retry_delay) + continue + analysis.thought = f"[Error fetching thought: {e}]" + + # Fetch observation if needed + if not analysis.observation and analysis.observation_document_uri: + doc_id = analysis.observation_document_uri + for attempt in range(self.max_retries): + try: + library = api.library() + content_bytes = library.get_document_content(user=user, id=doc_id) + try: + content = content_bytes.decode('utf-8') + if len(content) > max_content: + analysis.observation = content[:max_content] + "... [truncated]" + else: + analysis.observation = content + break + except UnicodeDecodeError: + analysis.observation = f"[Binary: {len(content_bytes)} bytes]" + break + except Exception as e: + if attempt < self.max_retries - 1: + time.sleep(self.retry_delay) + continue + analysis.observation = f"[Error fetching observation: {e}]" + + def fetch_graphrag_trace( + self, + question_uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + api: Any = None, + max_content: int = 10000 + ) -> Dict[str, Any]: + """ + Fetch the complete GraphRAG trace starting from a question URI. + + Follows the provenance chain: Question -> Exploration -> Focus -> Synthesis + + Args: + question_uri: The question entity URI + graph: Named graph (default: urn:graph:retrieval) + user: User/keyspace identifier + collection: Collection identifier + api: TrustGraph Api instance for librarian access (optional) + max_content: Maximum content length for synthesis + + Returns: + Dict with question, exploration, focus, synthesis entities + """ + if graph is None: + graph = "urn:graph:retrieval" + + trace = { + "question": None, + "exploration": None, + "focus": None, + "synthesis": None, + } + + # Fetch question + question = self.fetch_entity(question_uri, graph, user, collection) + if not isinstance(question, Question): + return trace + trace["question"] = question + + # Find exploration: ?exploration prov:wasGeneratedBy question_uri + exploration_triples = self.flow.triples_query( + p=PROV_WAS_GENERATED_BY, + o=question_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if exploration_triples: + exploration_uris = [ + extract_term_value(t.get("s", {})) + for t in exploration_triples + ] + for exp_uri in exploration_uris: + exploration = self.fetch_entity(exp_uri, graph, user, collection) + if isinstance(exploration, Exploration): + trace["exploration"] = exploration + break + + if not trace["exploration"]: + return trace + + # Find focus: ?focus prov:wasDerivedFrom exploration_uri + focus_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=trace["exploration"].uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if focus_triples: + focus_uris = [ + extract_term_value(t.get("s", {})) + for t in focus_triples + ] + for focus_uri in focus_uris: + focus = self.fetch_focus_with_edges(focus_uri, graph, user, collection) + if focus: + trace["focus"] = focus + break + + if not trace["focus"]: + return trace + + # Find synthesis: ?synthesis prov:wasDerivedFrom focus_uri + synthesis_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=trace["focus"].uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if synthesis_triples: + synthesis_uris = [ + extract_term_value(t.get("s", {})) + for t in synthesis_triples + ] + for synth_uri in synthesis_uris: + synthesis = self.fetch_entity(synth_uri, graph, user, collection) + if isinstance(synthesis, Synthesis): + # Fetch content if needed + if api and not synthesis.content and synthesis.document_uri: + synthesis.content = self.fetch_synthesis_content( + synthesis, api, user, max_content + ) + trace["synthesis"] = synthesis + break + + return trace + + def fetch_docrag_trace( + self, + question_uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + api: Any = None, + max_content: int = 10000 + ) -> Dict[str, Any]: + """ + Fetch the complete DocumentRAG trace starting from a question URI. + + Follows the provenance chain: Question -> Exploration -> Synthesis + (No Focus step for DocRAG since it doesn't do edge selection) + + Args: + question_uri: The question entity URI + graph: Named graph (default: urn:graph:retrieval) + user: User/keyspace identifier + collection: Collection identifier + api: TrustGraph Api instance for librarian access (optional) + max_content: Maximum content length for synthesis + + Returns: + Dict with question, exploration, synthesis entities + """ + if graph is None: + graph = "urn:graph:retrieval" + + trace = { + "question": None, + "exploration": None, + "synthesis": None, + } + + # Fetch question + question = self.fetch_entity(question_uri, graph, user, collection) + if not isinstance(question, Question): + return trace + trace["question"] = question + + # Find exploration: ?exploration prov:wasGeneratedBy question_uri + exploration_triples = self.flow.triples_query( + p=PROV_WAS_GENERATED_BY, + o=question_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if exploration_triples: + exploration_uris = [ + extract_term_value(t.get("s", {})) + for t in exploration_triples + ] + for exp_uri in exploration_uris: + exploration = self.fetch_entity(exp_uri, graph, user, collection) + if isinstance(exploration, Exploration): + trace["exploration"] = exploration + break + + if not trace["exploration"]: + return trace + + # Find synthesis: ?synthesis prov:wasDerivedFrom exploration_uri + # (DocRAG goes directly from exploration to synthesis, no focus step) + synthesis_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=trace["exploration"].uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if synthesis_triples: + synthesis_uris = [ + extract_term_value(t.get("s", {})) + for t in synthesis_triples + ] + for synth_uri in synthesis_uris: + synthesis = self.fetch_entity(synth_uri, graph, user, collection) + if isinstance(synthesis, Synthesis): + # Fetch content if needed + if api and not synthesis.content and synthesis.document_uri: + synthesis.content = self.fetch_synthesis_content( + synthesis, api, user, max_content + ) + trace["synthesis"] = synthesis + break + + return trace + + def fetch_agent_trace( + self, + session_uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + api: Any = None, + max_content: int = 10000 + ) -> Dict[str, Any]: + """ + Fetch the complete Agent trace starting from a session URI. + + Follows the provenance chain: Question -> Analysis(s) -> Conclusion + + Args: + session_uri: The agent session/question URI + graph: Named graph (default: urn:graph:retrieval) + user: User/keyspace identifier + collection: Collection identifier + api: TrustGraph Api instance for librarian access (optional) + max_content: Maximum content length for conclusion + + Returns: + Dict with question, iterations (Analysis list), conclusion entities + """ + if graph is None: + graph = "urn:graph:retrieval" + + trace = { + "question": None, + "iterations": [], + "conclusion": None, + } + + # Fetch question/session + question = self.fetch_entity(session_uri, graph, user, collection) + if not isinstance(question, Question): + return trace + trace["question"] = question + + # Follow the chain of wasDerivedFrom + current_uri = session_uri + max_iterations = 50 # Safety limit + + for _ in range(max_iterations): + # Find entity derived from current + derived_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=current_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if not derived_triples: + break + + derived_uri = extract_term_value(derived_triples[0].get("s", {})) + if not derived_uri: + break + + entity = self.fetch_entity(derived_uri, graph, user, collection) + + if isinstance(entity, Analysis): + # Fetch thought/observation content from librarian if needed + if api: + self.fetch_analysis_content( + entity, api, user=user, max_content=max_content + ) + trace["iterations"].append(entity) + current_uri = derived_uri + elif isinstance(entity, Conclusion): + # Fetch answer content from librarian if needed + if api and not entity.answer and entity.document_uri: + entity.answer = self.fetch_conclusion_content( + entity, api, user=user, max_content=max_content + ) + trace["conclusion"] = entity + break + else: + # Unknown entity type, stop + break + + return trace + + def list_sessions( + self, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + limit: int = 50 + ) -> List[Question]: + """ + List all explainability sessions (questions) in a collection. + + Args: + graph: Named graph (default: urn:graph:retrieval) + user: User/keyspace identifier + collection: Collection identifier + limit: Maximum number of sessions to return + + Returns: + List of Question entities sorted by timestamp (newest first) + """ + if graph is None: + graph = "urn:graph:retrieval" + + # Query for all triples with predicate = tg:query + query_triples = self.flow.triples_query( + p=TG_QUERY, + g=graph, + user=user, + collection=collection, + limit=limit + ) + + questions = [] + for t in query_triples: + question_uri = extract_term_value(t.get("s", {})) + if question_uri: + entity = self.fetch_entity(question_uri, graph, user, collection) + if isinstance(entity, Question): + questions.append(entity) + + # Sort by timestamp (newest first) + questions.sort(key=lambda q: q.timestamp or "", reverse=True) + + return questions + + def detect_session_type( + self, + session_uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None + ) -> str: + """ + Detect whether a session is GraphRAG or Agent type. + + Args: + session_uri: The session/question URI + graph: Named graph + user: User/keyspace identifier + collection: Collection identifier + + Returns: + "graphrag" or "agent" + """ + if graph is None: + graph = "urn:graph:retrieval" + + # Fast path: check URI pattern + if "agent" in session_uri: + return "agent" + if "question" in session_uri: + return "graphrag" + if "docrag" in session_uri: + return "docrag" + + # Check what's derived from this entity + derived_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=session_uri, + g=graph, + user=user, + collection=collection, + limit=5 + ) + + generated_triples = self.flow.triples_query( + p=PROV_WAS_GENERATED_BY, + o=session_uri, + g=graph, + user=user, + collection=collection, + limit=5 + ) + + all_child_uris = [ + extract_term_value(t.get("s", {})) + for t in (derived_triples + generated_triples) + ] + + for child_uri in all_child_uris: + entity = self.fetch_entity(child_uri, graph, user, collection) + if isinstance(entity, Analysis): + return "agent" + if isinstance(entity, Exploration): + return "graphrag" + + return "graphrag" # Default diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index a08b8bca..4e09351a 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -15,6 +15,63 @@ from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, Strea from . exceptions import ProtocolException, raise_from_error_dict +def build_term(value: Any, term_type: Optional[str] = None, + datatype: Optional[str] = None, language: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + Build wire-format Term dict from a value. + + Auto-detection rules (when term_type is None): + - Already a dict with 't' key -> return as-is (already a Term) + - Starts with http://, https://, urn: -> IRI + - Wrapped in <> (e.g., ) -> IRI (angle brackets stripped) + - Anything else -> literal + + Args: + value: The term value (string, dict, or None) + term_type: One of 'iri', 'literal', or None for auto-detect + datatype: Datatype for literal objects (e.g., xsd:integer) + language: Language tag for literal objects (e.g., en) + + Returns: + dict: Wire-format Term dict, or None if value is None + """ + if value is None: + return None + + # If already a Term dict, return as-is + if isinstance(value, dict) and "t" in value: + return value + + # Convert to string for processing + value = str(value) + + # Auto-detect type if not specified + if term_type is None: + if value.startswith("<") and value.endswith(">") and not value.startswith("<<"): + # Angle-bracket wrapped IRI: + value = value[1:-1] # Strip < and > + term_type = "iri" + elif value.startswith(("http://", "https://", "urn:")): + term_type = "iri" + else: + term_type = "literal" + + if term_type == "iri": + # Strip angle brackets if present + if value.startswith("<") and value.endswith(">"): + value = value[1:-1] + return {"t": "i", "i": value} + elif term_type == "literal": + result = {"t": "l", "v": value} + if datatype: + result["dt"] = datatype + if language: + result["ln"] = language + return result + else: + raise ValueError(f"Unknown term type: {term_type}") + + class SocketClient: """ Synchronous WebSocket client for streaming operations. @@ -92,7 +149,8 @@ class SocketClient: flow: Optional[str], request: Dict[str, Any], streaming: bool = False, - streaming_raw: bool = False + streaming_raw: bool = False, + include_provenance: bool = False ) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]: """Synchronous wrapper around async WebSocket communication. @@ -119,7 +177,7 @@ class SocketClient: return self._streaming_generator_raw(service, flow, request, loop) elif streaming: # Parsed streaming for agent/RAG chunk types - return self._streaming_generator(service, flow, request, loop) + return self._streaming_generator(service, flow, request, loop, include_provenance) else: # Non-streaming single response return loop.run_until_complete(self._send_request_async(service, flow, request)) @@ -129,10 +187,11 @@ class SocketClient: service: str, flow: Optional[str], request: Dict[str, Any], - loop: asyncio.AbstractEventLoop + loop: asyncio.AbstractEventLoop, + include_provenance: bool = False ) -> Iterator[StreamingChunk]: """Generator that yields streaming chunks (for agent/RAG responses)""" - async_gen = self._send_request_async_streaming(service, flow, request) + async_gen = self._send_request_async_streaming(service, flow, request, include_provenance) try: while True: @@ -265,7 +324,8 @@ class SocketClient: self, service: str, flow: Optional[str], - request: Dict[str, Any] + request: Dict[str, Any], + include_provenance: bool = False ) -> Iterator[StreamingChunk]: """Async implementation of WebSocket request (streaming)""" # Generate unique request ID @@ -309,8 +369,8 @@ class SocketClient: raise_from_error_dict(resp["error"]) # Parse different chunk types - chunk = self._parse_chunk(resp) - if chunk is not None: # Skip provenance messages in streaming + chunk = self._parse_chunk(resp, include_provenance=include_provenance) + if chunk is not None: # Skip provenance messages unless include_provenance yield chunk # Check if this is the final message @@ -325,14 +385,26 @@ class SocketClient: chunk_type = resp.get("chunk_type") message_type = resp.get("message_type") - # Handle new GraphRAG message format with message_type - if message_type == "provenance": + # Handle GraphRAG/DocRAG message format with message_type + if message_type == "explain": if include_provenance: # Return provenance event for explainability - return ProvenanceEvent(provenance_id=resp.get("provenance_id", "")) + return ProvenanceEvent( + explain_id=resp.get("explain_id", ""), + explain_graph=resp.get("explain_graph", "") + ) # Provenance messages are not yielded to user - they're metadata return None + # Handle Agent message format with chunk_type="explain" + if chunk_type == "explain": + if include_provenance: + return ProvenanceEvent( + explain_id=resp.get("explain_id", ""), + explain_graph=resp.get("explain_graph", "") + ) + return None + if chunk_type == "thought": return AgentThought( content=resp.get("content", ""), @@ -477,6 +549,95 @@ class SocketFlowInstance: # regardless of streaming flag, so always use the streaming code path return self.client._send_request_sync("agent", self.flow_id, request, streaming=True) + def agent_explain( + self, + question: str, + user: str, + collection: str, + state: Optional[Dict[str, Any]] = None, + group: Optional[str] = None, + history: Optional[List[Dict[str, Any]]] = None, + **kwargs: Any + ) -> Iterator[Union[StreamingChunk, ProvenanceEvent]]: + """ + Execute an agent operation with explainability support. + + Streams both content chunks (AgentThought, AgentObservation, AgentAnswer) + and provenance events (ProvenanceEvent). Provenance events contain URIs + that can be fetched using ExplainabilityClient to get detailed information + about the agent's reasoning process. + + Agent trace consists of: + - Session: The initial question and session metadata + - Iterations: Each thought/action/observation cycle + - Conclusion: The final answer + + Args: + question: User question or instruction + user: User identifier + collection: Collection identifier for provenance storage + state: Optional state dictionary for stateful conversations + group: Optional group identifier for multi-user contexts + history: Optional conversation history as list of message dicts + **kwargs: Additional parameters passed to the agent service + + Yields: + Union[StreamingChunk, ProvenanceEvent]: Agent chunks and provenance events + + Example: + ```python + from trustgraph.api import Api, ExplainabilityClient, ProvenanceEvent + from trustgraph.api import AgentThought, AgentObservation, AgentAnswer + + socket = api.socket() + flow = socket.flow("default") + explain_client = ExplainabilityClient(flow) + + provenance_ids = [] + for item in flow.agent_explain( + question="What is the capital of France?", + user="trustgraph", + collection="default" + ): + if isinstance(item, AgentThought): + print(f"[Thought] {item.content}") + elif isinstance(item, AgentObservation): + print(f"[Observation] {item.content}") + elif isinstance(item, AgentAnswer): + print(f"[Answer] {item.content}") + elif isinstance(item, ProvenanceEvent): + provenance_ids.append(item.explain_id) + + # Fetch session trace after completion + if provenance_ids: + trace = explain_client.fetch_agent_trace( + provenance_ids[0], # Session URI is first + graph="urn:graph:retrieval", + user="trustgraph", + collection="default" + ) + ``` + """ + request = { + "question": question, + "user": user, + "collection": collection, + "streaming": True # Always streaming for explain + } + if state is not None: + request["state"] = state + if group is not None: + request["group"] = group + if history is not None: + request["history"] = history + request.update(kwargs) + + # Use streaming with provenance enabled + return self.client._send_request_sync( + "agent", self.flow_id, request, + streaming=True, include_provenance=True + ) + def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]: """ Execute text completion with optional streaming. @@ -596,6 +757,86 @@ class SocketFlowInstance: else: return result.get("response", "") + def graph_rag_explain( + self, + query: str, + user: str, + collection: str, + max_subgraph_size: int = 1000, + max_subgraph_count: int = 5, + max_entity_distance: int = 3, + **kwargs: Any + ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: + """ + Execute graph-based RAG query with explainability support. + + Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent). + Provenance events contain URIs that can be fetched using ExplainabilityClient + to get detailed information about how the response was generated. + + Args: + query: Natural language query + user: User/keyspace identifier + collection: Collection identifier + max_subgraph_size: Maximum total triples in subgraph (default: 1000) + max_subgraph_count: Maximum number of subgraphs (default: 5) + max_entity_distance: Maximum traversal depth (default: 3) + **kwargs: Additional parameters passed to the service + + Yields: + Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events + + Example: + ```python + from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent + + socket = api.socket() + flow = socket.flow("default") + explain_client = ExplainabilityClient(flow) + + provenance_ids = [] + response_text = "" + + for item in flow.graph_rag_explain( + query="Tell me about Marie Curie", + user="trustgraph", + collection="scientists" + ): + if isinstance(item, RAGChunk): + response_text += item.content + print(item.content, end='', flush=True) + elif isinstance(item, ProvenanceEvent): + provenance_ids.append(item.provenance_id) + + # Fetch explainability details + for prov_id in provenance_ids: + entity = explain_client.fetch_entity( + prov_id, + graph="urn:graph:retrieval", + user="trustgraph", + collection="scientists" + ) + print(f"Entity: {entity}") + ``` + """ + request = { + "query": query, + "user": user, + "collection": collection, + "max-subgraph-size": max_subgraph_size, + "max-subgraph-count": max_subgraph_count, + "max-entity-distance": max_entity_distance, + "streaming": True, + "explainable": True, # Enable explainability mode + } + request.update(kwargs) + + # Use streaming with provenance events included + return self.client._send_request_sync( + "graph-rag", self.flow_id, request, + streaming=True, include_provenance=True + ) + def document_rag( self, query: str, @@ -654,6 +895,79 @@ class SocketFlowInstance: else: return result.get("response", "") + def document_rag_explain( + self, + query: str, + user: str, + collection: str, + doc_limit: int = 10, + **kwargs: Any + ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: + """ + Execute document-based RAG query with explainability support. + + Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent). + Provenance events contain URIs that can be fetched using ExplainabilityClient + to get detailed information about how the response was generated. + + Document RAG trace consists of: + - Question: The user's query + - Exploration: Chunks retrieved from document store (chunk_count) + - Synthesis: The generated answer + + Args: + query: Natural language query + user: User/keyspace identifier + collection: Collection identifier + doc_limit: Maximum document chunks to retrieve (default: 10) + **kwargs: Additional parameters passed to the service + + Yields: + Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events + + Example: + ```python + from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent + + socket = api.socket() + flow = socket.flow("default") + explain_client = ExplainabilityClient(flow) + + for item in flow.document_rag_explain( + query="Summarize the key findings", + user="trustgraph", + collection="research-papers", + doc_limit=5 + ): + if isinstance(item, RAGChunk): + print(item.content, end='', flush=True) + elif isinstance(item, ProvenanceEvent): + # Fetch entity details + entity = explain_client.fetch_entity( + item.explain_id, + graph=item.explain_graph, + user="trustgraph", + collection="research-papers" + ) + print(f"Event: {entity}", file=sys.stderr) + ``` + """ + request = { + "query": query, + "user": user, + "collection": collection, + "doc-limit": doc_limit, + "streaming": True, + "explainable": True, + } + request.update(kwargs) + + # Use streaming with provenance events included + return self.client._send_request_sync( + "document-rag", self.flow_id, request, + streaming=True, include_provenance=True + ) + def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: """Generator for RAG streaming (graph-rag and document-rag)""" for chunk in result: @@ -831,28 +1145,30 @@ class SocketFlowInstance: def triples_query( self, - s: Optional[str] = None, - p: Optional[str] = None, - o: Optional[str] = None, + s: Optional[Union[str, Dict[str, Any]]] = None, + p: Optional[Union[str, Dict[str, Any]]] = None, + o: Optional[Union[str, Dict[str, Any]]] = None, + g: Optional[str] = None, user: Optional[str] = None, collection: Optional[str] = None, limit: int = 100, **kwargs: Any - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: """ Query knowledge graph triples using pattern matching. Args: - s: Subject URI (optional, use None for wildcard) - p: Predicate URI (optional, use None for wildcard) - o: Object URI or Literal (optional, use None for wildcard) + s: Subject filter - URI string, Term dict, or None for wildcard + p: Predicate filter - URI string, Term dict, or None for wildcard + o: Object filter - URI/literal string, Term dict, or None for wildcard + g: Named graph filter - URI string or None for all graphs user: User/keyspace identifier (optional) collection: Collection identifier (optional) limit: Maximum results to return (default: 100) **kwargs: Additional parameters passed to the service Returns: - dict: Query results with matching triples + List[Dict]: List of matching triples in wire format Example: ```python @@ -860,33 +1176,54 @@ class SocketFlowInstance: flow = socket.flow("default") # Find all triples about a specific subject - result = flow.triples_query( + triples = flow.triples_query( s="http://example.org/person/marie-curie", user="trustgraph", collection="scientists" ) + + # Query with named graph filter + triples = flow.triples_query( + s="urn:trustgraph:session:abc123", + g="urn:graph:retrieval", + user="trustgraph", + collection="default" + ) ``` """ request = {"limit": limit} - if s is not None: - request["s"] = str(s) - if p is not None: - request["p"] = str(p) - if o is not None: - request["o"] = str(o) + + # Build Term dicts for s/p/o (auto-converts strings) + s_term = build_term(s) + p_term = build_term(p) + o_term = build_term(o) + + if s_term is not None: + request["s"] = s_term + if p_term is not None: + request["p"] = p_term + if o_term is not None: + request["o"] = o_term + if g is not None: + request["g"] = g if user is not None: request["user"] = user if collection is not None: request["collection"] = collection request.update(kwargs) - return self.client._send_request_sync("triples", self.flow_id, request, False) + result = self.client._send_request_sync("triples", self.flow_id, request, False) + # Return the triples list from the response + if isinstance(result, dict) and "response" in result: + return result["response"] + return result def triples_query_stream( self, - s: Optional[str] = None, - p: Optional[str] = None, - o: Optional[str] = None, + s: Optional[Union[str, Dict[str, Any]]] = None, + p: Optional[Union[str, Dict[str, Any]]] = None, + o: Optional[Union[str, Dict[str, Any]]] = None, + g: Optional[str] = None, user: Optional[str] = None, collection: Optional[str] = None, limit: int = 100, @@ -900,9 +1237,10 @@ class SocketFlowInstance: and memory overhead for large result sets. Args: - s: Subject URI (optional, use None for wildcard) - p: Predicate URI (optional, use None for wildcard) - o: Object URI or Literal (optional, use None for wildcard) + s: Subject filter - URI string, Term dict, or None for wildcard + p: Predicate filter - URI string, Term dict, or None for wildcard + o: Object filter - URI/literal string, Term dict, or None for wildcard + g: Named graph filter - URI string or None for all graphs user: User/keyspace identifier (optional) collection: Collection identifier (optional) limit: Maximum results to return (default: 100) @@ -930,12 +1268,20 @@ class SocketFlowInstance: "streaming": True, "batch-size": batch_size, } - if s is not None: - request["s"] = str(s) - if p is not None: - request["p"] = str(p) - if o is not None: - request["o"] = str(o) + + # Build Term dicts for s/p/o (auto-converts strings) + s_term = build_term(s) + p_term = build_term(p) + o_term = build_term(o) + + if s_term is not None: + request["s"] = s_term + if p_term is not None: + request["p"] = p_term + if o_term is not None: + request["o"] = o_term + if g is not None: + request["g"] = g if user is not None: request["user"] = user if collection is not None: diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index f66f7b82..d39310f2 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -212,19 +212,21 @@ class ProvenanceEvent: Each event represents a provenance node created during query processing. Attributes: - provenance_id: URI of the provenance node (e.g., urn:trustgraph:session:abc123) - event_type: Type of provenance event (session, retrieval, selection, answer) + explain_id: URI of the provenance node (e.g., urn:trustgraph:question:abc123) + explain_graph: Named graph where provenance triples are stored (e.g., urn:graph:retrieval) + event_type: Type of provenance event (question, exploration, focus, synthesis) """ - provenance_id: str - event_type: str = "" # Derived from provenance_id (session, retrieval, selection, answer) + explain_id: str + explain_graph: str = "" + event_type: str = "" # Derived from explain_id def __post_init__(self): - # Extract event type from provenance_id - if "session" in self.provenance_id: - self.event_type = "session" - elif "retrieval" in self.provenance_id: - self.event_type = "retrieval" - elif "selection" in self.provenance_id: - self.event_type = "selection" - elif "answer" in self.provenance_id: - self.event_type = "answer" + # Extract event type from explain_id + if "question" in self.explain_id: + self.event_type = "question" + elif "exploration" in self.explain_id: + self.event_type = "exploration" + elif "focus" in self.explain_id: + self.event_type = "focus" + elif "synthesis" in self.explain_id: + self.event_type = "synthesis" diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index 4da0aec6..378bdb41 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -59,6 +59,15 @@ class AgentResponseTranslator(MessageTranslator): result["end_of_message"] = getattr(obj, "end_of_message", False) result["end_of_dialog"] = getattr(obj, "end_of_dialog", False) + # Include explainability fields if present + explain_id = getattr(obj, "explain_id", None) + if explain_id: + result["explain_id"] = explain_id + + explain_graph = getattr(obj, "explain_graph", None) + if explain_graph is not None: + result["explain_graph"] = explain_graph + # Always include error if present if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "code": obj.error.code} diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index 9f102f9a..7326b722 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -34,7 +34,12 @@ class DocumentRagResponseTranslator(MessageTranslator): def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]: result = {} - # Include response content (even if empty string) + # Include message_type for distinguishing chunk vs explain messages + message_type = getattr(obj, "message_type", "") + if message_type: + result["message_type"] = message_type + + # Include response content for chunk messages if obj.response is not None: result["response"] = obj.response @@ -48,9 +53,12 @@ class DocumentRagResponseTranslator(MessageTranslator): if explain_graph is not None: result["explain_graph"] = explain_graph - # Include end_of_stream flag + # Include end_of_stream flag (LLM stream complete) result["end_of_stream"] = getattr(obj, "end_of_stream", False) + # Include end_of_session flag (entire session complete) + result["end_of_session"] = getattr(obj, "end_of_session", False) + # Always include error if present if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "type": obj.error.type} @@ -59,7 +67,8 @@ class DocumentRagResponseTranslator(MessageTranslator): def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - is_final = getattr(obj, 'end_of_stream', False) + # Session is complete when end_of_session is True + is_final = getattr(obj, 'end_of_session', False) return self.from_pulsar(obj), is_final diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index c1cb522a..b22f44d8 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -82,6 +82,10 @@ from . namespaces import ( TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, # Agent provenance predicates TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, + # Agent document references + TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT, + # Document reference predicate + TG_DOCUMENT, # Named graphs GRAPH_DEFAULT, GRAPH_SOURCE, GRAPH_RETRIEVAL, ) @@ -165,6 +169,10 @@ __all__ = [ "TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION", # Agent provenance predicates "TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_ANSWER", + # Agent document references + "TG_THOUGHT_DOCUMENT", "TG_OBSERVATION_DOCUMENT", + # Document reference predicate + "TG_DOCUMENT", # Named graphs "GRAPH_DEFAULT", "GRAPH_SOURCE", "GRAPH_RETRIEVAL", # Triple builders diff --git a/trustgraph-base/trustgraph/provenance/agent.py b/trustgraph-base/trustgraph/provenance/agent.py index 1f108795..e0ee9841 100644 --- a/trustgraph-base/trustgraph/provenance/agent.py +++ b/trustgraph-base/trustgraph/provenance/agent.py @@ -17,7 +17,8 @@ from . namespaces import ( RDF_TYPE, RDFS_LABEL, PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME, TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, - TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, + TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, + TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT, TG_AGENT_QUESTION, ) @@ -73,10 +74,12 @@ def agent_session_triples( def agent_iteration_triples( iteration_uri: str, parent_uri: str, - thought: str, - action: str, - arguments: Dict[str, Any], - observation: str, + thought: str = "", + action: str = "", + arguments: Dict[str, Any] = None, + observation: str = "", + thought_document_id: Optional[str] = None, + observation_document_id: Optional[str] = None, ) -> List[Triple]: """ Build triples for one agent iteration (Analysis - think/act/observe cycle). @@ -85,36 +88,53 @@ def agent_iteration_triples( - Entity declaration with tg:Analysis type - wasDerivedFrom link to parent (previous iteration or session) - Thought, action, arguments, and observation data + - Document references for thought/observation when stored in librarian Args: iteration_uri: URI of this iteration (from agent_iteration_uri) parent_uri: URI of the parent (previous iteration or session) - thought: The agent's reasoning/thought + thought: The agent's reasoning/thought (used if thought_document_id not provided) action: The tool/action name arguments: Arguments passed to the tool (will be JSON-encoded) - observation: The result/observation from the tool + observation: The result/observation from the tool (used if observation_document_id not provided) + thought_document_id: Optional document URI for thought in librarian (preferred) + observation_document_id: Optional document URI for observation in librarian (preferred) Returns: List of Triple objects """ + if arguments is None: + arguments = {} + triples = [ _triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)), _triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")), _triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)), - _triple(iteration_uri, TG_THOUGHT, _literal(thought)), _triple(iteration_uri, TG_ACTION, _literal(action)), _triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))), - _triple(iteration_uri, TG_OBSERVATION, _literal(observation)), ] + # Thought: use document reference or inline + if thought_document_id: + triples.append(_triple(iteration_uri, TG_THOUGHT_DOCUMENT, _iri(thought_document_id))) + elif thought: + triples.append(_triple(iteration_uri, TG_THOUGHT, _literal(thought))) + + # Observation: use document reference or inline + if observation_document_id: + triples.append(_triple(iteration_uri, TG_OBSERVATION_DOCUMENT, _iri(observation_document_id))) + elif observation: + triples.append(_triple(iteration_uri, TG_OBSERVATION, _literal(observation))) + return triples def agent_final_triples( final_uri: str, parent_uri: str, - answer: str, + answer: str = "", + document_id: Optional[str] = None, ) -> List[Triple]: """ Build triples for an agent final answer (Conclusion). @@ -122,20 +142,29 @@ def agent_final_triples( Creates: - Entity declaration with tg:Conclusion type - wasDerivedFrom link to parent (last iteration or session) - - The answer text + - Either document reference (if document_id provided) or inline answer Args: final_uri: URI of the final answer (from agent_final_uri) parent_uri: URI of the parent (last iteration or session if no iterations) - answer: The final answer text + answer: The final answer text (used if document_id not provided) + document_id: Optional document URI in librarian (preferred) Returns: List of Triple objects """ - return [ + triples = [ _triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)), _triple(final_uri, RDFS_LABEL, _literal("Conclusion")), _triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)), - _triple(final_uri, TG_ANSWER, _literal(answer)), ] + + if document_id: + # Store reference to document in librarian (as IRI) + triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id))) + elif answer: + # Fallback: store inline answer + triples.append(_triple(final_uri, TG_ANSWER, _literal(answer))) + + return triples diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 15f1b7d3..91e82eac 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -92,6 +92,10 @@ TG_ARGUMENTS = TG + "arguments" TG_OBSERVATION = TG + "observation" TG_ANSWER = TG + "answer" +# Agent document references (for librarian storage) +TG_THOUGHT_DOCUMENT = TG + "thoughtDocument" +TG_OBSERVATION_DOCUMENT = TG + "observationDocument" + # Named graph URIs for RDF datasets # These separate different types of data while keeping them in the same collection GRAPH_DEFAULT = "" # Core knowledge facts (triples extracted from documents) diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index 35a387fc..91179047 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -30,11 +30,15 @@ class AgentRequest: @dataclass class AgentResponse: # Streaming-first design - chunk_type: str = "" # "thought", "action", "observation", "answer", "error" + chunk_type: str = "" # "thought", "action", "observation", "answer", "explain", "error" content: str = "" # The actual content (interpretation depends on chunk_type) end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete end_of_dialog: bool = False # Entire agent dialog is complete + # Explainability fields + explain_id: str | None = None # Provenance URI (announced as created) + explain_graph: str | None = None # Named graph where explain was stored + # Legacy fields (deprecated but kept for backward compatibility) answer: str = "" error: Error | None = None diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index 5b09b11e..0d0b79b8 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -43,6 +43,8 @@ class DocumentRagQuery: class DocumentRagResponse: error: Error | None = None response: str | None = "" - end_of_stream: bool = False + end_of_stream: bool = False # LLM response stream complete explain_id: str | None = None # Single explain URI (announced as created) explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval) + message_type: str = "" # "chunk" or "explain" + end_of_session: bool = False # Entire session complete diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index 369fcdd4..03b71e2a 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -4,8 +4,19 @@ Uses the agent service to answer a question import argparse import os +import sys import textwrap -from trustgraph.api import Api +from trustgraph.api import ( + Api, + ExplainabilityClient, + ProvenanceEvent, + Question, + Analysis, + Conclusion, + AgentThought, + AgentObservation, + AgentAnswer, +) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @@ -97,11 +108,148 @@ def output(text, prefix="> ", width=78): ) print(out) +def question_explainable( + url, question_text, flow_id, user, collection, + state=None, group=None, verbose=False, token=None, debug=False +): + """Execute agent with explainability - shows provenance events inline.""" + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) + explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) + + try: + # Track last chunk type for formatting + last_chunk_type = None + current_outputter = None + + # Stream agent with explainability - process events as they arrive + for item in flow.agent_explain( + question=question_text, + user=user, + collection=collection, + state=state, + group=group, + ): + if isinstance(item, AgentThought): + if last_chunk_type != "thought": + if current_outputter: + current_outputter.__exit__(None, None, None) + current_outputter = None + print() # Blank line between message types + if verbose: + current_outputter = Outputter(width=78, prefix="\U0001f914 ") + current_outputter.__enter__() + last_chunk_type = "thought" + if current_outputter: + current_outputter.output(item.content) + if current_outputter.word_buffer: + print(current_outputter.word_buffer, end="", flush=True) + current_outputter.column += len(current_outputter.word_buffer) + current_outputter.word_buffer = "" + + elif isinstance(item, AgentObservation): + if last_chunk_type != "observation": + if current_outputter: + current_outputter.__exit__(None, None, None) + current_outputter = None + print() + if verbose: + current_outputter = Outputter(width=78, prefix="\U0001f4a1 ") + current_outputter.__enter__() + last_chunk_type = "observation" + if current_outputter: + current_outputter.output(item.content) + if current_outputter.word_buffer: + print(current_outputter.word_buffer, end="", flush=True) + current_outputter.column += len(current_outputter.word_buffer) + current_outputter.word_buffer = "" + + elif isinstance(item, AgentAnswer): + if last_chunk_type != "answer": + if current_outputter: + current_outputter.__exit__(None, None, None) + current_outputter = None + print() + last_chunk_type = "answer" + # Print answer content directly + print(item.content, end="", flush=True) + + elif isinstance(item, ProvenanceEvent): + # Process provenance event immediately + prov_id = item.explain_id + explain_graph = item.explain_graph or "urn:graph:retrieval" + + entity = explain_client.fetch_entity( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) + + if entity is None: + if debug: + print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr) + continue + + # Display based on entity type + if isinstance(entity, Question): + print(f"\n [session] {prov_id}", file=sys.stderr) + if entity.query: + print(f" Query: {entity.query}", file=sys.stderr) + if entity.timestamp: + print(f" Time: {entity.timestamp}", file=sys.stderr) + + elif isinstance(entity, Analysis): + print(f"\n [iteration] {prov_id}", file=sys.stderr) + if entity.thought: + thought_short = entity.thought[:80] + "..." if len(entity.thought) > 80 else entity.thought + print(f" Thought: {thought_short}", file=sys.stderr) + if entity.action: + print(f" Action: {entity.action}", file=sys.stderr) + + elif isinstance(entity, Conclusion): + print(f"\n [conclusion] {prov_id}", file=sys.stderr) + if entity.answer: + print(f" Answer length: {len(entity.answer)} chars", file=sys.stderr) + + else: + if debug: + print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr) + + # Close any remaining outputter + if current_outputter: + current_outputter.__exit__(None, None, None) + current_outputter = None + + # Final newline if we ended with answer + if last_chunk_type == "answer": + print() + + finally: + socket.close() + + def question( url, question, flow_id, user, collection, plan=None, state=None, group=None, verbose=False, streaming=True, - token=None + token=None, explainable=False, debug=False ): + # Explainable mode uses the API to capture and process provenance events + if explainable: + question_explainable( + url=url, + question_text=question, + flow_id=flow_id, + user=user, + collection=collection, + state=state, + group=group, + verbose=verbose, + token=token, + debug=debug + ) + return if verbose: output(wrap(question), "\U00002753 ") @@ -270,6 +418,18 @@ def main(): help=f'Disable streaming (use legacy mode)' ) + parser.add_argument( + '-x', '--explainable', + action='store_true', + help='Show provenance events: Session, Iterations, Conclusion (implies streaming)' + ) + + parser.add_argument( + '--debug', + action='store_true', + help='Show debug output for troubleshooting' + ) + args = parser.parse_args() try: @@ -286,6 +446,8 @@ def main(): verbose = args.verbose, streaming = not args.no_streaming, token = args.token, + explainable = args.explainable, + debug = args.debug, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index 7e88bdc4..4ed7bca9 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -4,7 +4,16 @@ Uses the DocumentRAG service to answer a question import argparse import os -from trustgraph.api import Api +import sys +from trustgraph.api import ( + Api, + ExplainabilityClient, + RAGChunk, + ProvenanceEvent, + Question, + Exploration, + Synthesis, +) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @@ -12,7 +21,90 @@ default_user = 'trustgraph' default_collection = 'default' default_doc_limit = 10 -def question(url, flow_id, question, user, collection, doc_limit, streaming=True, token=None): + +def question_explainable( + url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False +): + """Execute document RAG with explainability - shows provenance events inline.""" + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) + explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) + + try: + # Stream DocumentRAG with explainability - process events as they arrive + for item in flow.document_rag_explain( + query=question_text, + user=user, + collection=collection, + doc_limit=doc_limit, + ): + if isinstance(item, RAGChunk): + # Print response content + print(item.content, end="", flush=True) + + elif isinstance(item, ProvenanceEvent): + # Process provenance event immediately + prov_id = item.explain_id + explain_graph = item.explain_graph or "urn:graph:retrieval" + + entity = explain_client.fetch_entity( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) + + if entity is None: + if debug: + print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr) + continue + + # Display based on entity type + if isinstance(entity, Question): + print(f"\n [question] {prov_id}", file=sys.stderr) + if entity.query: + print(f" Query: {entity.query}", file=sys.stderr) + if entity.timestamp: + print(f" Time: {entity.timestamp}", file=sys.stderr) + + elif isinstance(entity, Exploration): + print(f"\n [exploration] {prov_id}", file=sys.stderr) + if entity.chunk_count: + print(f" Chunks retrieved: {entity.chunk_count}", file=sys.stderr) + + elif isinstance(entity, Synthesis): + print(f"\n [synthesis] {prov_id}", file=sys.stderr) + if entity.content: + print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr) + + else: + if debug: + print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr) + + print() # Final newline + + finally: + socket.close() + + +def question( + url, flow_id, question_text, user, collection, doc_limit, + streaming=True, token=None, explainable=False, debug=False +): + # Explainable mode uses the API to capture and process provenance events + if explainable: + question_explainable( + url=url, + flow_id=flow_id, + question_text=question_text, + user=user, + collection=collection, + doc_limit=doc_limit, + token=token, + debug=debug + ) + return # Create API client api = Api(url=url, token=token) @@ -24,7 +116,7 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True try: response = flow.document_rag( - query=question, + query=question_text, user=user, collection=collection, doc_limit=doc_limit, @@ -42,13 +134,14 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True # Use REST API for non-streaming flow = api.flow().id(flow_id) resp = flow.document_rag( - query=question, + query=question_text, user=user, collection=collection, doc_limit=doc_limit, ) print(resp) + def main(): parser = argparse.ArgumentParser( @@ -105,6 +198,18 @@ def main(): help='Disable streaming (use non-streaming mode)' ) + parser.add_argument( + '-x', '--explainable', + action='store_true', + help='Show provenance events: Question, Exploration, Synthesis (implies streaming)' + ) + + parser.add_argument( + '--debug', + action='store_true', + help='Show debug output for troubleshooting' + ) + args = parser.parse_args() try: @@ -112,12 +217,14 @@ def main(): question( url=args.url, flow_id=args.flow_id, - question=args.question, + question_text=args.question, user=args.user, collection=args.collection, doc_limit=args.doc_limit, streaming=not args.no_streaming, token=args.token, + explainable=args.explainable, + debug=args.debug, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index db41f631..27c6854d 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -8,7 +8,16 @@ import os import sys import websockets import asyncio -from trustgraph.api import Api +from trustgraph.api import ( + Api, + ExplainabilityClient, + RAGChunk, + ProvenanceEvent, + Question, + Exploration, + Focus, + Synthesis, +) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @@ -602,18 +611,111 @@ async def _question_explainable( print() # Final newline +def _question_explainable_api( + url, flow_id, question_text, user, collection, entity_limit, triple_limit, + max_subgraph_size, max_path_length, token=None, debug=False +): + """Execute graph RAG with explainability using the new API classes.""" + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) + explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) + + try: + # Stream GraphRAG with explainability - process events as they arrive + for item in flow.graph_rag_explain( + query=question_text, + user=user, + collection=collection, + max_subgraph_size=max_subgraph_size, + max_subgraph_count=5, + max_entity_distance=max_path_length, + ): + if isinstance(item, RAGChunk): + # Print response content + print(item.content, end="", flush=True) + + elif isinstance(item, ProvenanceEvent): + # Process provenance event immediately + prov_id = item.explain_id + explain_graph = item.explain_graph or "urn:graph:retrieval" + + entity = explain_client.fetch_entity( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) + + if entity is None: + if debug: + print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr) + continue + + # Display based on entity type + if isinstance(entity, Question): + print(f"\n [question] {prov_id}", file=sys.stderr) + if entity.query: + print(f" Query: {entity.query}", file=sys.stderr) + if entity.timestamp: + print(f" Time: {entity.timestamp}", file=sys.stderr) + + elif isinstance(entity, Exploration): + print(f"\n [exploration] {prov_id}", file=sys.stderr) + if entity.edge_count: + print(f" Edges explored: {entity.edge_count}", file=sys.stderr) + + elif isinstance(entity, Focus): + print(f"\n [focus] {prov_id}", file=sys.stderr) + if entity.selected_edge_uris: + print(f" Focused on {len(entity.selected_edge_uris)} edge(s)", file=sys.stderr) + + # Fetch full focus with edge details + focus_full = explain_client.fetch_focus_with_edges( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) + if focus_full and focus_full.edge_selections: + for edge_sel in focus_full.edge_selections: + if edge_sel.edge: + # Resolve labels for edge components + s_label, p_label, o_label = explain_client.resolve_edge_labels( + edge_sel.edge, user, collection + ) + print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) + if edge_sel.reasoning: + r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning + print(f" Reason: {r_short}", file=sys.stderr) + + elif isinstance(entity, Synthesis): + print(f"\n [synthesis] {prov_id}", file=sys.stderr) + if entity.content: + print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr) + + else: + if debug: + print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr) + + print() # Final newline + + finally: + socket.close() + + def question( url, flow_id, question, user, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, streaming=True, token=None, explainable=False, debug=False ): - # Explainable mode uses direct websocket to capture provenance events + # Explainable mode uses the API to capture and process provenance events if explainable: - asyncio.run(_question_explainable( + _question_explainable_api( url=url, flow_id=flow_id, - question=question, + question_text=question, user=user, collection=collection, entity_limit=entity_limit, @@ -622,7 +724,7 @@ def question( max_path_length=max_path_length, token=token, debug=debug - )) + ) return # Create API client diff --git a/trustgraph-cli/trustgraph/cli/list_explain_traces.py b/trustgraph-cli/trustgraph/cli/list_explain_traces.py index d2bb28ea..f545c53f 100644 --- a/trustgraph-cli/trustgraph/cli/list_explain_traces.py +++ b/trustgraph-cli/trustgraph/cli/list_explain_traces.py @@ -14,180 +14,17 @@ import json import os import sys from tabulate import tabulate -from trustgraph.api import Api +from trustgraph.api import Api, ExplainabilityClient default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_user = 'trustgraph' default_collection = 'default' -# Predicates -TG = "https://trustgraph.ai/ns/" -TG_QUERY = TG + "query" -TG_QUESTION = TG + "Question" -TG_ANALYSIS = TG + "Analysis" -TG_EXPLORATION = TG + "Exploration" -PROV = "http://www.w3.org/ns/prov#" -PROV_STARTED_AT_TIME = PROV + "startedAtTime" -PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" -PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy" -RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" - # Retrieval graph RETRIEVAL_GRAPH = "urn:graph:retrieval" -def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000): - """Query triples using the socket API.""" - request = { - "user": user, - "collection": collection, - "limit": limit, - "streaming": False, - } - - if s is not None: - request["s"] = {"t": "i", "i": s} - if p is not None: - request["p"] = {"t": "i", "i": p} - if o is not None: - if isinstance(o, str): - if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"): - request["o"] = {"t": "i", "i": o} - else: - request["o"] = {"t": "l", "v": o} - elif isinstance(o, dict): - request["o"] = o - if g is not None: - request["g"] = g - - triples = [] - try: - for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True): - if isinstance(response, dict): - triple_list = response.get("response", response.get("triples", [])) - else: - triple_list = response - - if not isinstance(triple_list, list): - triple_list = [triple_list] if triple_list else [] - - for t in triple_list: - s_val = extract_value(t.get("s", {})) - p_val = extract_value(t.get("p", {})) - o_val = extract_value(t.get("o", {})) - triples.append((s_val, p_val, o_val)) - except Exception as e: - print(f"Error querying triples: {e}", file=sys.stderr) - - return triples - - -def extract_value(term): - """Extract value from a term dict.""" - if not term: - return "" - - t = term.get("t") or term.get("type") - - if t == "i": - return term.get("i") or term.get("iri", "") - elif t == "l": - return term.get("v") or term.get("value", "") - elif t == "t": - # Quoted triple - tr = term.get("tr") or term.get("triple", {}) - return { - "s": extract_value(tr.get("s", {})), - "p": extract_value(tr.get("p", {})), - "o": extract_value(tr.get("o", {})), - } - - # Fallback for raw values - if "i" in term: - return term["i"] - if "v" in term: - return term["v"] - - return str(term) - - -def get_timestamp(socket, flow_id, user, collection, question_id): - """Get timestamp for a question.""" - triples = query_triples( - socket, flow_id, user, collection, - s=question_id, p=PROV_STARTED_AT_TIME, g=RETRIEVAL_GRAPH - ) - for s, p, o in triples: - return o - return "" - - -def get_session_type(socket, flow_id, user, collection, session_id): - """ - Get the type of session (Agent or GraphRAG). - - Both have tg:Question type, so we distinguish by URI pattern - or by checking what's derived from it. - """ - # Fast path: check URI pattern - if session_id.startswith("urn:trustgraph:agent:"): - return "Agent" - if session_id.startswith("urn:trustgraph:question:"): - return "GraphRAG" - - # Check what's derived from this entity - derived = query_triples( - socket, flow_id, user, collection, - p=PROV_WAS_DERIVED_FROM, o=session_id, g=RETRIEVAL_GRAPH - ) - generated = query_triples( - socket, flow_id, user, collection, - p=PROV_WAS_GENERATED_BY, o=session_id, g=RETRIEVAL_GRAPH - ) - - for s, p, o in derived + generated: - child_types = query_triples( - socket, flow_id, user, collection, - s=s, p=RDF_TYPE, g=RETRIEVAL_GRAPH - ) - for _, _, child_type in child_types: - if child_type == TG_ANALYSIS: - return "Agent" - if child_type == TG_EXPLORATION: - return "GraphRAG" - - return "GraphRAG" - - -def list_sessions(socket, flow_id, user, collection, limit): - """List all explainability sessions (GraphRAG and Agent) by finding questions.""" - # Query for all triples with predicate = tg:query - triples = query_triples( - socket, flow_id, user, collection, - p=TG_QUERY, g=RETRIEVAL_GRAPH, limit=limit - ) - - sessions = [] - for question_id, _, query_text in triples: - # Get timestamp if available - timestamp = get_timestamp(socket, flow_id, user, collection, question_id) - # Get session type (Agent or GraphRAG) - session_type = get_session_type(socket, flow_id, user, collection, question_id) - - sessions.append({ - "id": question_id, - "type": session_type, - "question": query_text, - "time": timestamp, - }) - - # Sort by timestamp (newest first) if available - sessions.sort(key=lambda x: x.get("time", ""), reverse=True) - - return sessions - - def truncate_text(text, max_len=60): """Truncate text to max length with ellipsis.""" if not text: @@ -277,16 +114,42 @@ def main(): try: api = Api(args.api_url, token=args.token) socket = api.socket() + flow = socket.flow(args.flow_id) + explain_client = ExplainabilityClient(flow) try: - sessions = list_sessions( - socket=socket, - flow_id=args.flow_id, + # List all sessions using the API + questions = explain_client.list_sessions( + graph=RETRIEVAL_GRAPH, user=args.user, collection=args.collection, limit=args.limit, ) + # Convert to output format + sessions = [] + for q in questions: + session_type = explain_client.detect_session_type( + q.uri, + graph=RETRIEVAL_GRAPH, + user=args.user, + collection=args.collection + ) + + # Map type names + type_display = { + "graphrag": "GraphRAG", + "docrag": "DocRAG", + "agent": "Agent", + }.get(session_type, session_type.title()) + + sessions.append({ + "id": q.uri, + "type": type_display, + "question": q.query, + "time": q.timestamp, + }) + if args.format == 'json': print_json(sessions) else: diff --git a/trustgraph-cli/trustgraph/cli/query_graph.py b/trustgraph-cli/trustgraph/cli/query_graph.py index a123e632..a2c38353 100644 --- a/trustgraph-cli/trustgraph/cli/query_graph.py +++ b/trustgraph-cli/trustgraph/cli/query_graph.py @@ -291,42 +291,25 @@ def query_graph( ): """Query the triple store with pattern matching. - Uses the WebSocket API's raw streaming mode for efficient delivery of results. + Uses the API's triples_query_stream for efficient streaming delivery. """ socket = Api(url, token=token).socket() - - # Build request dict directly (bypassing triples_query_stream's string conversion) - request = { - "user": user, - "collection": collection, - "limit": limit, - "streaming": True, - "batch-size": batch_size, - } - - # Add term dicts for s/p/o (None means wildcard) - if subject is not None: - request["s"] = subject - if predicate is not None: - request["p"] = predicate - if obj is not None: - request["o"] = obj - if graph is not None: - request["g"] = graph + flow = socket.flow(flow_id) all_triples = [] try: - # Use raw streaming mode - yields response dicts directly - for response in socket._send_request_sync( - "triples", flow_id, request, streaming_raw=True + # Use triples_query_stream - accepts Term dicts directly + for triples in flow.triples_query_stream( + s=subject, + p=predicate, + o=obj, + g=graph, + user=user, + collection=collection, + limit=limit, + batch_size=batch_size, ): - # Response may have triples in different locations depending on format - if isinstance(response, dict): - triples = response.get("response", response.get("triples", [])) - else: - triples = response - if not isinstance(triples, list): triples = [triples] if triples else [] diff --git a/trustgraph-cli/trustgraph/cli/show_explain_trace.py b/trustgraph-cli/trustgraph/cli/show_explain_trace.py index d09b220c..b3fc7058 100644 --- a/trustgraph-cli/trustgraph/cli/show_explain_trace.py +++ b/trustgraph-cli/trustgraph/cli/show_explain_trace.py @@ -18,228 +18,99 @@ import argparse import json import os import sys -from trustgraph.api import Api +from trustgraph.api import ( + Api, + ExplainabilityClient, + Question, + Exploration, + Focus, + Synthesis, + Analysis, + Conclusion, +) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_user = 'trustgraph' default_collection = 'default' -# Predicates -TG = "https://trustgraph.ai/ns/" -TG_QUERY = TG + "query" -TG_EDGE_COUNT = TG + "edgeCount" -TG_SELECTED_EDGE = TG + "selectedEdge" -TG_EDGE = TG + "edge" -TG_REASONING = TG + "reasoning" -TG_CONTENT = TG + "content" -TG_DOCUMENT = TG + "document" -TG_REIFIES = TG + "reifies" -# Explainability entity types -TG_QUESTION = TG + "Question" -TG_EXPLORATION = TG + "Exploration" -TG_FOCUS = TG + "Focus" -TG_SYNTHESIS = TG + "Synthesis" -TG_ANALYSIS = TG + "Analysis" -TG_CONCLUSION = TG + "Conclusion" - -# Agent predicates -TG_THOUGHT = TG + "thought" -TG_ACTION = TG + "action" -TG_ARGUMENTS = TG + "arguments" -TG_OBSERVATION = TG + "observation" -TG_ANSWER = TG + "answer" -PROV = "http://www.w3.org/ns/prov#" -PROV_STARTED_AT_TIME = PROV + "startedAtTime" -PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" -PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy" -RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" -RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" - # Graphs RETRIEVAL_GRAPH = "urn:graph:retrieval" SOURCE_GRAPH = "urn:graph:source" - -def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000): - """Query triples using the socket API.""" - request = { - "user": user, - "collection": collection, - "limit": limit, - "streaming": False, - } - - if s is not None: - request["s"] = {"t": "i", "i": s} - if p is not None: - request["p"] = {"t": "i", "i": p} - if o is not None: - if isinstance(o, str): - if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"): - request["o"] = {"t": "i", "i": o} - else: - request["o"] = {"t": "l", "v": o} - elif isinstance(o, dict): - request["o"] = o - if g is not None: - request["g"] = g - - triples = [] - try: - for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True): - if isinstance(response, dict): - triple_list = response.get("response", response.get("triples", [])) - else: - triple_list = response - - if not isinstance(triple_list, list): - triple_list = [triple_list] if triple_list else [] - - for t in triple_list: - s_val = extract_value(t.get("s", {})) - p_val = extract_value(t.get("p", {})) - o_val = extract_value(t.get("o", {})) - triples.append((s_val, p_val, o_val)) - except Exception as e: - print(f"Error querying triples: {e}", file=sys.stderr) - - return triples +# Provenance predicates for edge tracing +TG = "https://trustgraph.ai/ns/" +TG_REIFIES = TG + "reifies" +PROV = "http://www.w3.org/ns/prov#" +PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" -def extract_value(term): - """Extract value from a term dict.""" - if not term: - return "" +def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_client): + """ + Trace an edge back to its source document via reification. - t = term.get("t") or term.get("type") + Args: + flow: SocketFlowInstance + user: User identifier + collection: Collection identifier + edge: Dict with s, p, o keys + label_cache: Dict for caching labels + explain_client: ExplainabilityClient for label resolution - if t == "i": - return term.get("i") or term.get("iri", "") - elif t == "l": - return term.get("v") or term.get("value", "") - elif t == "t": - # Quoted triple - tr = term.get("tr") or term.get("triple", {}) - return { - "s": extract_value(tr.get("s", {})), - "p": extract_value(tr.get("p", {})), - "o": extract_value(tr.get("o", {})), - } + Returns: + List of provenance chains, each chain is list of {uri, label} + """ + edge_s = edge.get("s", "") + edge_p = edge.get("p", "") + edge_o = edge.get("o", "") - # Fallback for raw values - if "i" in term: - return term["i"] - if "v" in term: - return term["v"] + # Build quoted triple for lookup + def build_term(val): + if isinstance(val, str) and (val.startswith("http") or val.startswith("urn:")): + return {"t": "i", "i": val} + return {"t": "l", "v": str(val)} - return str(term) - - -def get_node_properties(socket, flow_id, user, collection, node_uri, graph=RETRIEVAL_GRAPH): - """Get all properties of a node as a dict.""" - triples = query_triples(socket, flow_id, user, collection, s=node_uri, g=graph) - props = {} - for s, p, o in triples: - if p not in props: - props[p] = [] - props[p].append(o) - return props - - -def find_by_predicate_object(socket, flow_id, user, collection, predicate, obj, graph=RETRIEVAL_GRAPH): - """Find subjects where predicate = obj.""" - triples = query_triples(socket, flow_id, user, collection, p=predicate, o=obj, g=graph) - return [s for s, p, o in triples] - - -def get_label(socket, flow_id, user, collection, uri, label_cache): - """Get label for a URI, with caching.""" - if not isinstance(uri, str) or not (uri.startswith("http://") or uri.startswith("https://") or uri.startswith("urn:")): - return uri - - if uri in label_cache: - return label_cache[uri] - - triples = query_triples(socket, flow_id, user, collection, s=uri, p=RDFS_LABEL) - for s, p, o in triples: - label_cache[uri] = o - return o - - label_cache[uri] = uri - return uri - - -def get_document_content(api, user, doc_id, max_content): - """Fetch document content from librarian API.""" - try: - library = api.library() - content = library.get_document_content(user=user, id=doc_id) - - # Try to decode as text - try: - text = content.decode('utf-8') - if len(text) > max_content: - return text[:max_content] + "... [truncated]" - return text - except UnicodeDecodeError: - return f"[Binary: {len(content)} bytes]" - except Exception as e: - return f"[Error fetching content: {e}]" - - -def trace_edge_provenance(socket, flow_id, user, collection, edge_s, edge_p, edge_o, label_cache): - """Trace an edge back to its source document via reification.""" - # Build the quoted triple for lookup quoted_triple = { "t": "t", "tr": { - "s": {"t": "i", "i": edge_s} if isinstance(edge_s, str) and (edge_s.startswith("http") or edge_s.startswith("urn:")) else {"t": "l", "v": edge_s}, - "p": {"t": "i", "i": edge_p}, - "o": {"t": "i", "i": edge_o} if isinstance(edge_o, str) and (edge_o.startswith("http") or edge_o.startswith("urn:")) else {"t": "l", "v": edge_o}, + "s": build_term(edge_s), + "p": build_term(edge_p), + "o": build_term(edge_o), } } # Query: ?stmt tg:reifies <> - request = { - "user": user, - "collection": collection, - "limit": 10, - "streaming": False, - "p": {"t": "i", "i": TG_REIFIES}, - "o": quoted_triple, - "g": SOURCE_GRAPH, - } - - stmt_uris = [] try: - for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True): - if isinstance(response, dict): - triple_list = response.get("response", response.get("triples", [])) - else: - triple_list = response - - if not isinstance(triple_list, list): - triple_list = [triple_list] if triple_list else [] - - for t in triple_list: - s_val = extract_value(t.get("s", {})) - if s_val: - stmt_uris.append(s_val) + results = flow.triples_query( + p=TG_REIFIES, + o=quoted_triple, + g=SOURCE_GRAPH, + user=user, + collection=collection, + limit=10 + ) except Exception: - pass + return [] - # For each statement, find wasDerivedFrom chain + # Extract statement URIs + stmt_uris = [] + for t in results: + s_term = t.get("s", {}) + s_val = s_term.get("i") or s_term.get("v", "") + if s_val: + stmt_uris.append(s_val) + + # For each statement, trace wasDerivedFrom chain provenance_chains = [] for stmt_uri in stmt_uris: - chain = trace_provenance_chain(socket, flow_id, user, collection, stmt_uri, label_cache) + chain = trace_provenance_chain(flow, user, collection, stmt_uri, label_cache, explain_client) if chain: provenance_chains.append(chain) return provenance_chains -def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_cache, max_depth=10): +def trace_provenance_chain(flow, user, collection, start_uri, label_cache, explain_client, max_depth=10): """Trace prov:wasDerivedFrom chain from start_uri to root.""" chain = [] current = start_uri @@ -248,17 +119,32 @@ def trace_provenance_chain(socket, flow_id, user, collection, start_uri, label_c if not current: break - label = get_label(socket, flow_id, user, collection, current, label_cache) + # Get label + if current in label_cache: + label = label_cache[current] + else: + label = explain_client.resolve_label(current, user, collection) + label_cache[current] = label + chain.append({"uri": current, "label": label}) - # Get parent - triples = query_triples( - socket, flow_id, user, collection, - s=current, p=PROV_WAS_DERIVED_FROM, g=SOURCE_GRAPH - ) + # Get parent via wasDerivedFrom + try: + results = flow.triples_query( + s=current, + p=PROV_WAS_DERIVED_FROM, + g=SOURCE_GRAPH, + user=user, + collection=collection, + limit=1 + ) + except Exception: + break + parent = None - for s, p, o in triples: - parent = o + for t in results: + o_term = t.get("o", {}) + parent = o_term.get("i") or o_term.get("v", "") break if not parent or parent == current: @@ -276,331 +162,24 @@ def format_provenance_chain(chain): return " -> ".join(labels) -def format_edge(edge, label_cache=None, socket=None, flow_id=None, user=None, collection=None): - """Format a quoted triple edge for display.""" - if not isinstance(edge, dict): - return str(edge) +def print_graphrag_text(trace, explain_client, flow, user, collection, show_provenance=False): + """Print GraphRAG trace in text format.""" + question = trace.get("question") - s = edge.get("s", "?") - p = edge.get("p", "?") - o = edge.get("o", "?") - - # Get labels if available - if label_cache and socket: - s_label = get_label(socket, flow_id, user, collection, s, label_cache) - p_label = get_label(socket, flow_id, user, collection, p, label_cache) - o_label = get_label(socket, flow_id, user, collection, o, label_cache) - else: - # Shorten URIs for display - s_label = s.split("/")[-1] if "/" in str(s) else s - p_label = p.split("/")[-1] if "/" in str(p) else p - o_label = o.split("/")[-1] if "/" in str(o) else o - - return f"({s_label}, {p_label}, {o_label})" - - -def detect_trace_type(socket, flow_id, user, collection, entity_id): - """ - Detect whether an entity is an agent Question or GraphRAG Question. - - Both have rdf:type = tg:Question, so we distinguish by checking - what's derived from it: - - Agent: has tg:Analysis or tg:Conclusion derived - - GraphRAG: has tg:Exploration derived - - Also checks URI pattern as fallback: - - urn:trustgraph:agent: -> agent - - urn:trustgraph:question: -> graphrag - - Returns: - "agent" or "graphrag" - """ - # Check URI pattern first (fast path) - if entity_id.startswith("urn:trustgraph:agent:"): - return "agent" - if entity_id.startswith("urn:trustgraph:question:"): - return "graphrag" - - # Check what's derived from this entity - derived = find_by_predicate_object( - socket, flow_id, user, collection, - PROV_WAS_DERIVED_FROM, entity_id - ) - - # Also check wasGeneratedBy (GraphRAG exploration uses this) - generated = find_by_predicate_object( - socket, flow_id, user, collection, - PROV_WAS_GENERATED_BY, entity_id - ) - - all_children = derived + generated - - for child_id in all_children: - child_types = query_triples( - socket, flow_id, user, collection, - s=child_id, p=RDF_TYPE, g=RETRIEVAL_GRAPH - ) - for s, p, o in child_types: - if o == TG_ANALYSIS or o == TG_CONCLUSION: - return "agent" - if o == TG_EXPLORATION: - return "graphrag" - - # Default to graphrag - return "graphrag" - - -def build_agent_trace(socket, flow_id, user, collection, session_id, api=None, max_answer=500): - """Build the full explainability trace for an agent session.""" - trace = { - "session_id": session_id, - "type": "agent", - "question": None, - "time": None, - "iterations": [], - "final_answer": None, - } - - # Get session metadata - props = get_node_properties(socket, flow_id, user, collection, session_id) - trace["question"] = props.get(TG_QUERY, [None])[0] - trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0] - - # Find all entities derived from this session (iterations and final) - # Start by looking for entities where prov:wasDerivedFrom = session_id - current_uri = session_id - iteration_num = 1 - - while True: - # Find entities derived from current - derived_ids = find_by_predicate_object( - socket, flow_id, user, collection, - PROV_WAS_DERIVED_FROM, current_uri - ) - - if not derived_ids: - break - - derived_id = derived_ids[0] - derived_props = get_node_properties(socket, flow_id, user, collection, derived_id) - - # Check type - types = derived_props.get(RDF_TYPE, []) - - if TG_ANALYSIS in types: - iteration = { - "id": derived_id, - "iteration_num": iteration_num, - "thought": derived_props.get(TG_THOUGHT, [None])[0], - "action": derived_props.get(TG_ACTION, [None])[0], - "arguments": derived_props.get(TG_ARGUMENTS, [None])[0], - "observation": derived_props.get(TG_OBSERVATION, [None])[0], - } - trace["iterations"].append(iteration) - current_uri = derived_id - iteration_num += 1 - - elif TG_CONCLUSION in types: - answer = derived_props.get(TG_ANSWER, [None])[0] - if answer and len(answer) > max_answer: - answer = answer[:max_answer] + "... [truncated]" - trace["final_answer"] = { - "id": derived_id, - "answer": answer, - } - break - - else: - # Unknown type, stop traversal - break - - return trace - - -def print_agent_text(trace): - """Print agent trace in text format.""" - print(f"=== Agent Session: {trace['session_id']} ===") + print(f"=== GraphRAG Session: {question.uri if question else 'Unknown'} ===") print() - if trace["question"]: - print(f"Question: {trace['question']}") - if trace["time"]: - print(f"Time: {trace['time']}") - print() - - # Analysis steps - print("--- Analysis ---") - iterations = trace.get("iterations", []) - if iterations: - for iteration in iterations: - print(f"Analysis {iteration['iteration_num']}:") - print(f" Thought: {iteration.get('thought', 'N/A')}") - print(f" Action: {iteration.get('action', 'N/A')}") - - args = iteration.get('arguments') - if args: - # Try to pretty-print JSON arguments - try: - import json - args_obj = json.loads(args) - args_str = json.dumps(args_obj, indent=4) - # Indent each line - args_lines = args_str.split('\n') - print(f" Arguments:") - for line in args_lines: - print(f" {line}") - except: - print(f" Arguments: {args}") - else: - print(f" Arguments: N/A") - - obs = iteration.get('observation', 'N/A') - if obs and len(obs) > 200: - obs = obs[:200] + "... [truncated]" - print(f" Observation: {obs}") - print() - else: - print("No analysis steps recorded") - print() - - # Conclusion - print("--- Conclusion ---") - final = trace.get("final_answer") - if final and final.get("answer"): - print("Answer:") - for line in final["answer"].split("\n"): - print(f" {line}") - else: - print("No conclusion recorded") - - -def print_agent_json(trace): - """Print agent trace as JSON.""" - print(json.dumps(trace, indent=2)) - - -def build_trace(socket, flow_id, user, collection, question_id, api=None, show_provenance=False, max_answer=500): - """Build the full explainability trace for a question.""" - label_cache = {} - - trace = { - "question_id": question_id, - "question": None, - "time": None, - "exploration": None, - "focus": None, - "synthesis": None, - } - - # Get question metadata - props = get_node_properties(socket, flow_id, user, collection, question_id) - trace["question"] = props.get(TG_QUERY, [None])[0] - trace["time"] = props.get(PROV_STARTED_AT_TIME, [None])[0] - - # Find exploration: ?exploration prov:wasGeneratedBy question_id - exploration_ids = find_by_predicate_object( - socket, flow_id, user, collection, - PROV_WAS_GENERATED_BY, question_id - ) - - if exploration_ids: - exploration_id = exploration_ids[0] - exploration_props = get_node_properties(socket, flow_id, user, collection, exploration_id) - trace["exploration"] = { - "id": exploration_id, - "edge_count": exploration_props.get(TG_EDGE_COUNT, [None])[0], - } - - # Find focus: ?focus prov:wasDerivedFrom exploration_id - focus_ids = find_by_predicate_object( - socket, flow_id, user, collection, - PROV_WAS_DERIVED_FROM, exploration_id - ) - - if focus_ids: - focus_id = focus_ids[0] - focus_props = get_node_properties(socket, flow_id, user, collection, focus_id) - - # Get selected edges - edge_selection_uris = focus_props.get(TG_SELECTED_EDGE, []) - selected_edges = [] - - for edge_sel_uri in edge_selection_uris: - edge_sel_props = get_node_properties(socket, flow_id, user, collection, edge_sel_uri) - edge = edge_sel_props.get(TG_EDGE, [None])[0] - reasoning = edge_sel_props.get(TG_REASONING, [None])[0] - - edge_info = { - "edge": edge, - "reasoning": reasoning, - } - - # Trace provenance if requested - if show_provenance and isinstance(edge, dict): - provenance = trace_edge_provenance( - socket, flow_id, user, collection, - edge.get("s", ""), edge.get("p", ""), edge.get("o", ""), - label_cache - ) - edge_info["provenance"] = provenance - - selected_edges.append(edge_info) - - trace["focus"] = { - "id": focus_id, - "selected_edges": selected_edges, - } - - # Find synthesis: ?synthesis prov:wasDerivedFrom focus_id - synthesis_ids = find_by_predicate_object( - socket, flow_id, user, collection, - PROV_WAS_DERIVED_FROM, focus_id - ) - - if synthesis_ids: - synthesis_id = synthesis_ids[0] - synthesis_props = get_node_properties(socket, flow_id, user, collection, synthesis_id) - - # Get content directly or via document reference - content = synthesis_props.get(TG_CONTENT, [None])[0] - doc_id = synthesis_props.get(TG_DOCUMENT, [None])[0] - - if not content and doc_id and api: - content = get_document_content(api, user, doc_id, max_answer) - elif content and len(content) > max_answer: - content = content[:max_answer] + "... [truncated]" - - trace["synthesis"] = { - "id": synthesis_id, - "document_id": doc_id, - "answer": content, - } - - # Store label cache for formatting - trace["_label_cache"] = label_cache - - return trace - - -def print_text(trace, show_provenance=False): - """Print trace in text format.""" - label_cache = trace.get("_label_cache", {}) - - print(f"=== GraphRAG Session: {trace['question_id']} ===") - print() - - if trace["question"]: - print(f"Question: {trace['question']}") - if trace["time"]: - print(f"Time: {trace['time']}") + if question: + print(f"Question: {question.query}") + if question.timestamp: + print(f"Time: {question.timestamp}") print() # Exploration print("--- Exploration ---") exploration = trace.get("exploration") if exploration: - edge_count = exploration.get("edge_count", "?") - print(f"Retrieved {edge_count} edges from knowledge graph") + print(f"Retrieved {exploration.edge_count} edges from knowledge graph") else: print("No exploration data found") print() @@ -609,24 +188,28 @@ def print_text(trace, show_provenance=False): print("--- Focus (Edge Selection) ---") focus = trace.get("focus") if focus: - edges = focus.get("selected_edges", []) + edges = focus.edge_selections print(f"Selected {len(edges)} edges:") print() - for i, edge_info in enumerate(edges, 1): - edge = edge_info.get("edge") - reasoning = edge_info.get("reasoning") + label_cache = {} - if edge: - edge_str = format_edge(edge) - print(f" {i}. {edge_str}") + for i, edge_sel in enumerate(edges, 1): + if edge_sel.edge: + s_label, p_label, o_label = explain_client.resolve_edge_labels( + edge_sel.edge, user, collection + ) + print(f" {i}. ({s_label}, {p_label}, {o_label})") - if reasoning: - r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning + if edge_sel.reasoning: + r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning print(f" Reasoning: {r_short}") - if show_provenance: - provenance = edge_info.get("provenance", []) + if show_provenance and edge_sel.edge: + provenance = trace_edge_provenance( + flow, user, collection, edge_sel.edge, + label_cache, explain_client + ) for chain in provenance: chain_str = format_provenance_chain(chain) if chain_str: @@ -641,11 +224,9 @@ def print_text(trace, show_provenance=False): print("--- Synthesis ---") synthesis = trace.get("synthesis") if synthesis: - answer = synthesis.get("answer") - if answer: + if synthesis.content: print("Answer:") - # Indent the answer - for line in answer.split("\n"): + for line in synthesis.content.split("\n"): print(f" {line}") else: print("No answer content found") @@ -653,11 +234,173 @@ def print_text(trace, show_provenance=False): print("No synthesis data found") -def print_json(trace): - """Print trace as JSON.""" - # Remove internal cache before printing - output = {k: v for k, v in trace.items() if not k.startswith("_")} - print(json.dumps(output, indent=2)) +def print_docrag_text(trace): + """Print DocRAG trace in text format.""" + question = trace.get("question") + + print(f"=== DocRAG Session: {question.uri if question else 'Unknown'} ===") + print() + + if question: + print(f"Question: {question.query}") + if question.timestamp: + print(f"Time: {question.timestamp}") + print() + + # Exploration + print("--- Exploration ---") + exploration = trace.get("exploration") + if exploration: + print(f"Retrieved {exploration.chunk_count} chunks from document store") + else: + print("No exploration data found") + print() + + # Synthesis (no Focus step for DocRAG) + print("--- Synthesis ---") + synthesis = trace.get("synthesis") + if synthesis: + if synthesis.content: + print("Answer:") + for line in synthesis.content.split("\n"): + print(f" {line}") + else: + print("No answer content found") + else: + print("No synthesis data found") + + +def print_agent_text(trace): + """Print Agent trace in text format.""" + question = trace.get("question") + + print(f"=== Agent Session: {question.uri if question else 'Unknown'} ===") + print() + + if question: + print(f"Question: {question.query}") + if question.timestamp: + print(f"Time: {question.timestamp}") + print() + + # Analysis steps + print("--- Analysis ---") + iterations = trace.get("iterations", []) + if iterations: + for i, analysis in enumerate(iterations, 1): + print(f"Analysis {i}:") + print(f" Thought: {analysis.thought or 'N/A'}") + print(f" Action: {analysis.action or 'N/A'}") + + if analysis.arguments: + # Try to pretty-print JSON arguments + try: + args_obj = json.loads(analysis.arguments) + args_str = json.dumps(args_obj, indent=4) + print(f" Arguments:") + for line in args_str.split('\n'): + print(f" {line}") + except Exception: + print(f" Arguments: {analysis.arguments}") + else: + print(f" Arguments: N/A") + + obs = analysis.observation or 'N/A' + if obs and len(obs) > 200: + obs = obs[:200] + "... [truncated]" + print(f" Observation: {obs}") + print() + else: + print("No analysis steps recorded") + print() + + # Conclusion + print("--- Conclusion ---") + conclusion = trace.get("conclusion") + if conclusion and conclusion.answer: + print("Answer:") + for line in conclusion.answer.split("\n"): + print(f" {line}") + else: + print("No conclusion recorded") + + +def trace_to_dict(trace, trace_type): + """Convert trace entities to JSON-serializable dict.""" + if trace_type == "agent": + question = trace.get("question") + return { + "type": "agent", + "session_id": question.uri if question else None, + "question": question.query if question else None, + "time": question.timestamp if question else None, + "iterations": [ + { + "id": a.uri, + "thought": a.thought, + "action": a.action, + "arguments": a.arguments, + "observation": a.observation, + } + for a in trace.get("iterations", []) + ], + "conclusion": { + "id": trace["conclusion"].uri, + "answer": trace["conclusion"].answer, + } if trace.get("conclusion") else None, + } + elif trace_type == "docrag": + question = trace.get("question") + exploration = trace.get("exploration") + synthesis = trace.get("synthesis") + + return { + "type": "docrag", + "question_id": question.uri if question else None, + "question": question.query if question else None, + "time": question.timestamp if question else None, + "exploration": { + "id": exploration.uri, + "chunk_count": exploration.chunk_count, + } if exploration else None, + "synthesis": { + "id": synthesis.uri, + "document_uri": synthesis.document_uri, + "answer": synthesis.content, + } if synthesis else None, + } + else: + # graphrag + question = trace.get("question") + exploration = trace.get("exploration") + focus = trace.get("focus") + synthesis = trace.get("synthesis") + + return { + "type": "graphrag", + "question_id": question.uri if question else None, + "question": question.query if question else None, + "time": question.timestamp if question else None, + "exploration": { + "id": exploration.uri, + "edge_count": exploration.edge_count, + } if exploration else None, + "focus": { + "id": focus.uri, + "selected_edges": [ + { + "edge": edge_sel.edge, + "reasoning": edge_sel.reasoning, + } + for edge_sel in focus.edge_selections + ], + } if focus else None, + "synthesis": { + "id": synthesis.uri, + "document_uri": synthesis.document_uri, + "answer": synthesis.content, + } if synthesis else None, + } def main(): @@ -727,50 +470,69 @@ def main(): try: api = Api(args.api_url, token=args.token) socket = api.socket() + flow = socket.flow(args.flow_id) + explain_client = ExplainabilityClient(flow) try: - # Detect trace type (agent vs graphrag) - trace_type = detect_trace_type( - socket=socket, - flow_id=args.flow_id, + # Detect trace type + trace_type = explain_client.detect_session_type( + args.question_id, + graph=RETRIEVAL_GRAPH, user=args.user, collection=args.collection, - entity_id=args.question_id, ) if trace_type == "agent": - # Build and print agent trace - trace = build_agent_trace( - socket=socket, - flow_id=args.flow_id, + # Fetch and display agent trace + trace = explain_client.fetch_agent_trace( + args.question_id, + graph=RETRIEVAL_GRAPH, user=args.user, collection=args.collection, - session_id=args.question_id, api=api, - max_answer=args.max_answer, + max_content=args.max_answer, ) if args.format == 'json': - print_agent_json(trace) + print(json.dumps(trace_to_dict(trace, "agent"), indent=2)) else: print_agent_text(trace) - else: - # Build and print GraphRAG trace (existing behavior) - trace = build_trace( - socket=socket, - flow_id=args.flow_id, + + elif trace_type == "docrag": + # Fetch and display DocRAG trace + trace = explain_client.fetch_docrag_trace( + args.question_id, + graph=RETRIEVAL_GRAPH, user=args.user, collection=args.collection, - question_id=args.question_id, api=api, - show_provenance=args.show_provenance, - max_answer=args.max_answer, + max_content=args.max_answer, ) if args.format == 'json': - print_json(trace) + print(json.dumps(trace_to_dict(trace, "docrag"), indent=2)) else: - print_text(trace, show_provenance=args.show_provenance) + print_docrag_text(trace) + + else: + # Fetch and display GraphRAG trace + trace = explain_client.fetch_graphrag_trace( + args.question_id, + graph=RETRIEVAL_GRAPH, + user=args.user, + collection=args.collection, + api=api, + max_content=args.max_answer, + ) + + if args.format == 'json': + print(json.dumps(trace_to_dict(trace, "graphrag"), indent=2)) + else: + print_graphrag_text( + trace, explain_client, flow, + args.user, args.collection, + show_provenance=args.show_provenance + ) finally: socket.close() diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 2ae93b05..306c081e 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -2,6 +2,8 @@ Simple agent infrastructure broadly implements the ReAct flow. """ +import asyncio +import base64 import json import re import sys @@ -17,9 +19,13 @@ from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec from ... base import ProducerSpec +from ... base import Consumer, Producer +from ... base import ConsumerMetrics, ProducerMetrics from ... schema import AgentRequest, AgentResponse, AgentStep, Error from ... schema import Triples, Metadata +from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ... schema import librarian_request_queue, librarian_response_queue # Provenance imports for agent explainability from trustgraph.provenance import ( @@ -41,6 +47,8 @@ from . types import Final, Action, Tool, Argument default_ident = "agent-manager" default_max_iterations = 10 +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue class Processor(AgentService): @@ -129,6 +137,115 @@ class Processor(AgentService): ) ) + # Librarian client for storing answer content + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor=id, flow=None, name="librarian-request" + ) + + self.librarian_request_producer = Producer( + backend=self.pubsub, + topic=librarian_request_q, + schema=LibrarianRequest, + metrics=librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor=id, flow=None, name="librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=librarian_response_q, + subscriber=f"{id}-librarian", + schema=LibrarianResponse, + handler=self.on_librarian_response, + metrics=librarian_response_metrics, + ) + + # Pending librarian requests: request_id -> asyncio.Future + self.pending_librarian_requests = {} + + async def start(self): + await super(Processor, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + """Handle responses from the librarian service.""" + response = msg.value() + request_id = msg.properties().get("id") + + if request_id in self.pending_librarian_requests: + future = self.pending_librarian_requests.pop(request_id) + future.set_result(response) + else: + logger.warning(f"Received unexpected librarian response: {request_id}") + + async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + """ + Save answer content to the librarian. + + Args: + doc_id: ID for the answer document + user: User ID + content: Answer text content + title: Optional title + timeout: Request timeout in seconds + + Returns: + The document ID on success + """ + request_id = str(uuid.uuid4()) + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or "Agent Answer", + document_type="answer", + ) + + request = LibrarianRequest( + operation="add-document", + document_id=doc_id, + document_metadata=doc_metadata, + content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_librarian_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving answer: {response.error.type}: {response.error.message}" + ) + + return doc_id + + except asyncio.TimeoutError: + self.pending_librarian_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving answer document {doc_id}") + async def on_tools_config(self, config, version): logger.info(f"Loading configuration version {version}") @@ -347,6 +464,15 @@ class Processor(AgentService): )) logger.debug(f"Emitted session triples for {session_uri}") + # Send explain event for session + if streaming: + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=session_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + logger.info(f"Question: {request.question}") if len(history) >= self.max_iterations: @@ -504,8 +630,28 @@ class Processor(AgentService): else: parent_uri = session_uri + # Save answer to librarian + answer_doc_id = None + if f: + answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer" + try: + await self.save_answer_content( + doc_id=answer_doc_id, + user=request.user, + content=f, + title=f"Agent Answer: {request.question[:50]}...", + ) + logger.debug(f"Saved answer to librarian: {answer_doc_id}") + except Exception as e: + logger.warning(f"Failed to save answer to librarian: {e}") + answer_doc_id = None # Fall back to inline content + final_triples = set_graph( - agent_final_triples(final_uri, parent_uri, f), + agent_final_triples( + final_uri, parent_uri, + answer="" if answer_doc_id else f, + document_id=answer_doc_id, + ), GRAPH_RETRIEVAL ) await flow("explainability").send(Triples( @@ -518,6 +664,15 @@ class Processor(AgentService): )) logger.debug(f"Emitted final triples for {final_uri}") + # Send explain event for conclusion + if streaming: + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=final_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + if streaming: # Streaming format - send end-of-dialog marker # Answer chunks were already sent via answer() callback during parsing @@ -558,14 +713,48 @@ class Processor(AgentService): else: parent_uri = session_uri + # Save thought to librarian + thought_doc_id = None + if act.thought: + thought_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought" + try: + await self.save_answer_content( + doc_id=thought_doc_id, + user=request.user, + content=act.thought, + title=f"Agent Thought: {act.name}", + ) + logger.debug(f"Saved thought to librarian: {thought_doc_id}") + except Exception as e: + logger.warning(f"Failed to save thought to librarian: {e}") + thought_doc_id = None + + # Save observation to librarian + observation_doc_id = None + if act.observation: + observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" + try: + await self.save_answer_content( + doc_id=observation_doc_id, + user=request.user, + content=act.observation, + title=f"Agent Observation: {act.name}", + ) + logger.debug(f"Saved observation to librarian: {observation_doc_id}") + except Exception as e: + logger.warning(f"Failed to save observation to librarian: {e}") + observation_doc_id = None + iter_triples = set_graph( agent_iteration_triples( iteration_uri, parent_uri, - act.thought, - act.name, - act.arguments, - act.observation, + thought="" if thought_doc_id else act.thought, + action=act.name, + arguments=act.arguments, + observation="" if observation_doc_id else act.observation, + thought_document_id=thought_doc_id, + observation_document_id=observation_doc_id, ), GRAPH_RETRIEVAL ) @@ -579,6 +768,15 @@ class Processor(AgentService): )) logger.debug(f"Emitted iteration triples for {iteration_uri}") + # Send explain event for iteration + if streaming: + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=iteration_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + history.append(act) # Handle state transitions if tool execution was successful diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index 7730ceac..78c97024 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -109,7 +109,7 @@ class DocumentRag: async def query( self, query, user="trustgraph", collection="default", doc_limit=20, streaming=False, chunk_callback=None, - explain_callback=None, + explain_callback=None, save_answer_callback=None, ): """ Execute a Document RAG query with optional explainability tracking. @@ -122,6 +122,7 @@ class DocumentRag: streaming: Enable streaming LLM response chunk_callback: async def callback(chunk, end_of_stream) for streaming explain_callback: async def callback(triples, explain_id) for explainability + save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian Returns: str: The synthesized answer text @@ -192,9 +193,28 @@ class DocumentRag: # Emit synthesis explainability after answer generated if explain_callback: + synthesis_doc_id = None answer_text = resp if resp else "" + + # Save answer to librarian if callback provided + if save_answer_callback and answer_text: + # Generate document ID as URN matching query-time provenance format + synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer" + try: + await save_answer_callback(synthesis_doc_id, answer_text) + if self.verbose: + logger.debug(f"Saved answer to librarian: {synthesis_doc_id}") + except Exception as e: + logger.warning(f"Failed to save answer to librarian: {e}") + synthesis_doc_id = None # Fall back to inline content + + # Generate triples with document reference or inline content syn_triples = set_graph( - docrag_synthesis_triples(syn_uri, exp_uri, answer_text), + docrag_synthesis_triples( + syn_uri, exp_uri, + answer_text="" if synthesis_doc_id else answer_text, + document_id=synthesis_doc_id, + ), GRAPH_RETRIEVAL ) await explain_callback(syn_triples, syn_uri) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index c1d96260..9eb32e12 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -8,8 +8,10 @@ import asyncio import base64 import logging +import uuid + from ... schema import DocumentRagQuery, DocumentRagResponse, Error -from ... schema import LibrarianRequest, LibrarianResponse +from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue from ... schema import Triples, Metadata from ... provenance import GRAPH_RETRIEVAL @@ -179,6 +181,62 @@ class Processor(FlowProcessor): self.pending_requests.pop(request_id, None) raise RuntimeError(f"Timeout fetching chunk {chunk_id}") + async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + """ + Save answer content to the librarian. + + Args: + doc_id: ID for the answer document + user: User ID + content: Answer text content + title: Optional title + timeout: Request timeout in seconds + + Returns: + The document ID on success + """ + request_id = str(uuid.uuid4()) + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or "DocumentRAG Answer", + document_type="answer", + ) + + request = LibrarianRequest( + operation="add-document", + document_id=doc_id, + document_metadata=doc_metadata, + content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving answer: {response.error.type}: {response.error.message}" + ) + + return doc_id + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving answer document {doc_id}") + async def on_request(self, msg, consumer, flow): try: @@ -222,10 +280,20 @@ class Processor(FlowProcessor): response=None, explain_id=explain_id, explain_graph=GRAPH_RETRIEVAL, + message_type="explain", ), properties={"id": id} ) + # Callback to save answer content to librarian + async def save_answer(doc_id, answer_text): + await self.save_answer_content( + doc_id=doc_id, + user=v.user, + content=answer_text, + title=f"DocumentRAG Answer: {v.query[:50]}...", + ) + # Check if streaming is requested if v.streaming: # Define async callback for streaming chunks @@ -235,6 +303,7 @@ class Processor(FlowProcessor): DocumentRagResponse( response=chunk, end_of_stream=end_of_stream, + message_type="chunk", error=None ), properties={"id": id} @@ -250,6 +319,17 @@ class Processor(FlowProcessor): streaming=True, chunk_callback=send_chunk, explain_callback=send_explainability, + save_answer_callback=save_answer, + ) + + # Send end_of_session to signal entire session is complete + await flow("response").send( + DocumentRagResponse( + response=None, + end_of_session=True, + message_type="end", + ), + properties={"id": id} ) else: # Non-streaming path (existing behavior) @@ -259,6 +339,7 @@ class Processor(FlowProcessor): collection=v.collection, doc_limit=doc_limit, explain_callback=send_explainability, + save_answer_callback=save_answer, ) await flow("response").send(