diff --git a/specs/api/components/schemas/rag/GraphRagResponse.yaml b/specs/api/components/schemas/rag/GraphRagResponse.yaml index 75f4f059..de6c74cc 100644 --- a/specs/api/components/schemas/rag/GraphRagResponse.yaml +++ b/specs/api/components/schemas/rag/GraphRagResponse.yaml @@ -1,13 +1,27 @@ type: object -description: Graph RAG response +description: Graph RAG response message properties: + message_type: + type: string + description: Type of message - "chunk" for LLM response chunks, "explain" for provenance announcements + enum: [chunk, explain] + example: chunk response: type: string - description: Generated response based on retrieved knowledge graph + description: Generated response text (for chunk messages) example: Quantum physics and computer science intersect in quantum computing... - end-of-stream: + explain_id: + type: string + description: Provenance node URI (for explain messages) + example: urn:trustgraph:session:abc123 + end_of_stream: type: boolean - description: Indicates streaming is complete (streaming mode) + description: Indicates LLM response stream is complete + default: false + example: true + end_of_session: + type: boolean + description: Indicates entire session is complete (all messages sent) default: false example: true error: diff --git a/tests/contract/test_translator_completion_flags.py b/tests/contract/test_translator_completion_flags.py index c01156ae..a92705a0 100644 --- a/tests/contract/test_translator_completion_flags.py +++ b/tests/contract/test_translator_completion_flags.py @@ -17,16 +17,18 @@ from trustgraph.messaging import TranslatorRegistry class TestRAGTranslatorCompletionFlags: """Contract tests for RAG response translator completion flags""" - def test_graph_rag_translator_is_final_with_end_of_stream_true(self): + def test_graph_rag_translator_is_final_with_end_of_session_true(self): """ Test that GraphRagResponseTranslator returns is_final=True - when end_of_stream=True. + when end_of_session=True. 
""" # Arrange translator = TranslatorRegistry.get_response_translator("graph-rag") response = GraphRagResponse( response="A small domesticated mammal.", + message_type="chunk", end_of_stream=True, + end_of_session=True, error=None ) @@ -34,20 +36,23 @@ class TestRAGTranslatorCompletionFlags: response_dict, is_final = translator.from_response_with_completion(response) # Assert - assert is_final is True, "is_final must be True when end_of_stream=True" + assert is_final is True, "is_final must be True when end_of_session=True" assert response_dict["response"] == "A small domesticated mammal." - assert response_dict["end_of_stream"] is True + assert response_dict["end_of_session"] is True + assert response_dict["message_type"] == "chunk" - def test_graph_rag_translator_is_final_with_end_of_stream_false(self): + def test_graph_rag_translator_is_final_with_end_of_session_false(self): """ Test that GraphRagResponseTranslator returns is_final=False - when end_of_stream=False. + when end_of_session=False (even if end_of_stream=True). """ # Arrange translator = TranslatorRegistry.get_response_translator("graph-rag") response = GraphRagResponse( response="Chunk 1", + message_type="chunk", end_of_stream=False, + end_of_session=False, error=None ) @@ -55,9 +60,55 @@ class TestRAGTranslatorCompletionFlags: response_dict, is_final = translator.from_response_with_completion(response) # Assert - assert is_final is False, "is_final must be False when end_of_stream=False" + assert is_final is False, "is_final must be False when end_of_session=False" assert response_dict["response"] == "Chunk 1" - assert response_dict["end_of_stream"] is False + assert response_dict["end_of_session"] is False + + def test_graph_rag_translator_provenance_message(self): + """ + Test that GraphRagResponseTranslator handles provenance messages. 
+ """ + # Arrange + translator = TranslatorRegistry.get_response_translator("graph-rag") + response = GraphRagResponse( + response="", + message_type="explain", + explain_id="urn:trustgraph:session:abc123", + end_of_stream=False, + end_of_session=False, + error=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is False + assert response_dict["message_type"] == "explain" + assert response_dict["explain_id"] == "urn:trustgraph:session:abc123" + + def test_graph_rag_translator_end_of_stream_not_final(self): + """ + Test that end_of_stream=True alone does NOT make is_final=True. + The session continues with provenance messages after LLM stream completes. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("graph-rag") + response = GraphRagResponse( + response="Final chunk", + message_type="chunk", + end_of_stream=True, + end_of_session=False, # Session continues with provenance + error=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is False, "end_of_stream=True should NOT make is_final=True" + assert response_dict["end_of_stream"] is True + assert response_dict["end_of_session"] is False def test_document_rag_translator_is_final_with_end_of_stream_true(self): """ diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index d7e39a2e..f0de5bb5 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -83,13 +83,25 @@ class TestGraphRagIntegration: @pytest.fixture def mock_prompt_client(self): - """Mock prompt client that generates realistic responses""" + """Mock prompt client that generates realistic responses for two-step process""" client = AsyncMock() - client.kg_prompt.return_value = ( - "Machine learning is a subset of artificial intelligence that enables computers " - 
"to learn from data without being explicitly programmed. It uses algorithms " - "and statistical models to find patterns in data." - ) + + # Mock responses for the two-step process: + # 1. kg-edge-selection returns JSONL with edge IDs + # 2. kg-synthesis returns the final answer + async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): + if prompt_name == "kg-edge-selection": + # Return empty selection (no edges selected) - valid JSONL + return "" + elif prompt_name == "kg-synthesis": + return ( + "Machine learning is a subset of artificial intelligence that enables computers " + "to learn from data without being explicitly programmed. It uses algorithms " + "and statistical models to find patterns in data." + ) + return "" + + client.prompt.side_effect = mock_prompt return client @pytest.fixture @@ -108,7 +120,7 @@ class TestGraphRagIntegration: async def test_graph_rag_end_to_end_flow(self, graph_rag, mock_embeddings_client, mock_graph_embeddings_client, mock_triples_client, mock_prompt_client): - """Test complete GraphRAG pipeline from query to response""" + """Test complete GraphRAG pipeline from query to response with real-time provenance""" # Arrange query = "What is machine learning?" user = "test_user" @@ -116,13 +128,20 @@ class TestGraphRagIntegration: entity_limit = 50 triple_limit = 30 + # Collect provenance events + provenance_events = [] + + async def collect_provenance(triples, prov_id): + provenance_events.append((triples, prov_id)) + # Act - result = await graph_rag.query( + response = await graph_rag.query( query=query, user=user, collection=collection, entity_limit=entity_limit, - triple_limit=triple_limit + triple_limit=triple_limit, + explain_callback=collect_provenance ) # Assert - Verify service coordination @@ -141,16 +160,19 @@ class TestGraphRagIntegration: # 3. Should query triples to build knowledge subgraph assert mock_triples_client.query_stream.call_count > 0 - # 4. 
Should call prompt with knowledge graph - mock_prompt_client.kg_prompt.assert_called_once() - call_args = mock_prompt_client.kg_prompt.call_args - assert call_args.args[0] == query # First arg is query - assert isinstance(call_args.args[1], list) # Second arg is kg (list of triples) + # 4. Should call prompt twice (edge selection + synthesis) + assert mock_prompt_client.prompt.call_count == 2 # Verify final response - assert result is not None - assert isinstance(result, str) - assert "machine learning" in result.lower() + assert response is not None + assert isinstance(response, str) + assert "machine learning" in response.lower() + + # Verify provenance was emitted in real-time (4 events: session, retrieval, selection, answer) + assert len(provenance_events) == 4 + for triples, prov_id in provenance_events: + assert isinstance(triples, list) + assert prov_id.startswith("urn:trustgraph:") @pytest.mark.asyncio async def test_graph_rag_with_different_limits(self, graph_rag, mock_embeddings_client, @@ -206,19 +228,25 @@ class TestGraphRagIntegration: mock_graph_embeddings_client.query.return_value = [] # No entities found mock_triples_client.query_stream.return_value = [] # No triples found + # Collect provenance + provenance_events = [] + + async def collect_provenance(triples, prov_id): + provenance_events.append((triples, prov_id)) + # Act - result = await graph_rag.query( + response = await graph_rag.query( query="unknown topic", user="test_user", - collection="test_collection" + collection="test_collection", + explain_callback=collect_provenance ) # Assert - # Should still call prompt client with empty knowledge graph - mock_prompt_client.kg_prompt.assert_called_once() - call_args = mock_prompt_client.kg_prompt.call_args - assert isinstance(call_args.args[1], list) # kg should be a list - assert result is not None + # Should still call prompt client (twice: edge selection + synthesis) + assert response is not None + # Provenance should still be emitted (4 
events) + assert len(provenance_events) == 4 @pytest.mark.asyncio async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client): diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index 99880510..d936de11 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -53,30 +53,34 @@ class TestGraphRagStreaming: @pytest.fixture def mock_streaming_prompt_client(self, mock_streaming_llm_response): - """Mock prompt client with streaming support""" + """Mock prompt client with streaming support for two-stage GraphRAG""" client = AsyncMock() - async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None): - # Both modes return the same text - full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data." + # Full synthesis text + full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data." 
- if streaming and chunk_callback: - # Simulate streaming chunks with end_of_stream flags - chunks = [] - async for chunk in mock_streaming_llm_response(): - chunks.append(chunk) + async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs): + if prompt_id == "kg-edge-selection": + # Edge selection returns JSONL with IDs - simulate selecting first edge + return '{"id": "abc12345", "reasoning": "Relevant to query"}\n' + elif prompt_id == "kg-synthesis": + if streaming and chunk_callback: + # Simulate streaming chunks with end_of_stream flags + chunks = [] + async for chunk in mock_streaming_llm_response(): + chunks.append(chunk) - # Send all chunks with end_of_stream=False except the last - for i, chunk in enumerate(chunks): - is_final = (i == len(chunks) - 1) - await chunk_callback(chunk, is_final) + # Send all chunks with end_of_stream=False except the last + for i, chunk in enumerate(chunks): + is_final = (i == len(chunks) - 1) + await chunk_callback(chunk, is_final) - return full_text - else: - # Non-streaming response - same text - return full_text + return full_text + else: + return full_text + return "" - client.kg_prompt.side_effect = kg_prompt_side_effect + client.prompt.side_effect = prompt_side_effect return client @pytest.fixture @@ -93,18 +97,25 @@ class TestGraphRagStreaming: @pytest.mark.asyncio async def test_graph_rag_streaming_basic(self, graph_rag_streaming, streaming_chunk_collector): - """Test basic GraphRAG streaming functionality""" + """Test basic GraphRAG streaming functionality with real-time provenance""" # Arrange query = "What is machine learning?" 
collector = streaming_chunk_collector() - # Act - result = await graph_rag_streaming.query( + # Collect provenance events + provenance_events = [] + + async def collect_provenance(triples, prov_id): + provenance_events.append((triples, prov_id)) + + # Act - query() returns response, provenance via callback + response = await graph_rag_streaming.query( query=query, user="test_user", collection="test_collection", streaming=True, - chunk_callback=collector.collect + chunk_callback=collector.collect, + explain_callback=collect_provenance ) # Assert @@ -116,10 +127,15 @@ class TestGraphRagStreaming: # Verify full response matches concatenated chunks full_from_chunks = collector.get_full_text() - assert result == full_from_chunks + assert response == full_from_chunks # Verify content is reasonable - assert "machine" in result.lower() or "learning" in result.lower() + assert "machine" in response.lower() or "learning" in response.lower() + + # Verify provenance was emitted in real-time (4 events) + assert len(provenance_events) == 4 + for triples, prov_id in provenance_events: + assert prov_id.startswith("urn:trustgraph:") @pytest.mark.asyncio async def test_graph_rag_streaming_vs_non_streaming(self, graph_rag_streaming): @@ -130,7 +146,7 @@ class TestGraphRagStreaming: collection = "test_collection" # Act - Non-streaming - non_streaming_result = await graph_rag_streaming.query( + non_streaming_response = await graph_rag_streaming.query( query=query, user=user, collection=collection, @@ -143,7 +159,7 @@ class TestGraphRagStreaming: async def collect(chunk, end_of_stream): streaming_chunks.append(chunk) - streaming_result = await graph_rag_streaming.query( + streaming_response = await graph_rag_streaming.query( query=query, user=user, collection=collection, @@ -152,9 +168,9 @@ class TestGraphRagStreaming: ) # Assert - Results should be equivalent - assert streaming_result == non_streaming_result + assert streaming_response == non_streaming_response assert 
len(streaming_chunks) > 0 - assert "".join(streaming_chunks) == streaming_result + assert "".join(streaming_chunks) == streaming_response @pytest.mark.asyncio async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming): @@ -163,7 +179,7 @@ class TestGraphRagStreaming: callback = AsyncMock() # Act - result = await graph_rag_streaming.query( + response = await graph_rag_streaming.query( query="test query", user="test_user", collection="test_collection", @@ -173,7 +189,7 @@ class TestGraphRagStreaming: # Assert assert callback.call_count > 0 - assert result is not None + assert response is not None # Verify all callback invocations had string arguments for call in callback.call_args_list: @@ -183,7 +199,7 @@ class TestGraphRagStreaming: async def test_graph_rag_streaming_without_callback(self, graph_rag_streaming): """Test streaming parameter without callback (should fall back to non-streaming)""" # Arrange & Act - result = await graph_rag_streaming.query( + response = await graph_rag_streaming.query( query="test query", user="test_user", collection="test_collection", @@ -192,8 +208,8 @@ class TestGraphRagStreaming: ) # Assert - Should complete without error - assert result is not None - assert isinstance(result, str) + assert response is not None + assert isinstance(response, str) @pytest.mark.asyncio async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming, @@ -204,7 +220,7 @@ class TestGraphRagStreaming: callback = AsyncMock() # Act - result = await graph_rag_streaming.query( + response = await graph_rag_streaming.query( query="unknown topic", user="test_user", collection="test_collection", @@ -213,7 +229,7 @@ class TestGraphRagStreaming: ) # Assert - Should still produce streamed response - assert result is not None + assert response is not None assert callback.call_count > 0 @pytest.mark.asyncio diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index 
4fa93afd..f5fe14b5 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -44,18 +44,23 @@ class TestGraphRagStreamingProtocol: """Mock prompt client that simulates realistic streaming with end_of_stream flags""" client = AsyncMock() - async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None): - if streaming and chunk_callback: - # Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True - await chunk_callback("The", False) - await chunk_callback(" answer", False) - await chunk_callback(" is here.", False) - await chunk_callback("", True) # Empty final chunk with end_of_stream=True - return "" # Return value not used since callback handles everything - else: - return "The answer is here." + async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None): + if prompt_name == "kg-edge-selection": + # Edge selection returns empty (no edges selected) + return "" + elif prompt_name == "kg-synthesis": + if streaming and chunk_callback: + # Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True + await chunk_callback("The", False) + await chunk_callback(" answer", False) + await chunk_callback(" is here.", False) + await chunk_callback("", True) # Empty final chunk with end_of_stream=True + return "" # Return value not used since callback handles everything + else: + return "The answer is here." 
+ return "" - client.kg_prompt.side_effect = kg_prompt_side_effect + client.prompt.side_effect = prompt_side_effect return client @pytest.fixture @@ -327,20 +332,24 @@ class TestStreamingProtocolEdgeCases: # Arrange client = AsyncMock() - async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None): - if streaming and chunk_callback: - await chunk_callback("text", False) - await chunk_callback("", False) # Empty but not final - await chunk_callback("more", False) - await chunk_callback("", True) # Empty and final + async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None): + if prompt_name == "kg-edge-selection": return "" - else: - return "textmore" + elif prompt_name == "kg-synthesis": + if streaming and chunk_callback: + await chunk_callback("text", False) + await chunk_callback("", False) # Empty but not final + await chunk_callback("more", False) + await chunk_callback("", True) # Empty and final + return "" + else: + return "textmore" + return "" - client.kg_prompt.side_effect = kg_prompt_with_empties + client.prompt.side_effect = prompt_with_empties rag = GraphRag( - embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])), + embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])), graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])), triples_client=AsyncMock(query=AsyncMock(return_value=[])), prompt_client=client, diff --git a/tests/unit/test_direct/test_entity_centric_kg.py b/tests/unit/test_direct/test_entity_centric_kg.py index df70efee..72c66a42 100644 --- a/tests/unit/test_direct/test_entity_centric_kg.py +++ b/tests/unit/test_direct/test_entity_centric_kg.py @@ -547,21 +547,21 @@ class TestServiceHelperFunctions: """Test cases for helper functions in service.py""" def test_create_term_with_uri_otype(self): - """Test create_term creates IRI Term for otype='u'""" + """Test create_term creates IRI Term for term_type='u'""" from 
trustgraph.query.triples.cassandra.service import create_term from trustgraph.schema import IRI - term = create_term('http://example.org/Alice', otype='u') + term = create_term('http://example.org/Alice', term_type='u') assert term.type == IRI assert term.iri == 'http://example.org/Alice' def test_create_term_with_literal_otype(self): - """Test create_term creates LITERAL Term for otype='l'""" + """Test create_term creates LITERAL Term for term_type='l'""" from trustgraph.query.triples.cassandra.service import create_term from trustgraph.schema import LITERAL - term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en') + term = create_term('Alice Smith', term_type='l', datatype='xsd:string', language='en') assert term.type == LITERAL assert term.value == 'Alice Smith' @@ -569,7 +569,7 @@ class TestServiceHelperFunctions: assert term.language == 'en' def test_create_term_with_triple_otype(self): - """Test create_term creates TRIPLE Term for otype='t' with valid JSON""" + """Test create_term creates TRIPLE Term for term_type='t' with valid JSON""" from trustgraph.query.triples.cassandra.service import create_term from trustgraph.schema import TRIPLE, IRI import json @@ -581,7 +581,7 @@ class TestServiceHelperFunctions: "o": {"type": "i", "iri": "http://example.org/Bob"}, }) - term = create_term(triple_json, otype='t') + term = create_term(triple_json, term_type='t') assert term.type == TRIPLE assert term.triple is not None diff --git a/tests/unit/test_gateway/test_streaming_translators.py b/tests/unit/test_gateway/test_streaming_translators.py index e767edd4..e190fe68 100644 --- a/tests/unit/test_gateway/test_streaming_translators.py +++ b/tests/unit/test_gateway/test_streaming_translators.py @@ -96,20 +96,21 @@ class TestGraphRagResponseTranslator: assert is_final is False assert result["end_of_stream"] is False - # Test final chunk with empty content + # Test final message with end_of_session=True final_response = GraphRagResponse( response="", 
end_of_stream=True, + end_of_session=True, error=None ) # Act result, is_final = translator.from_response_with_completion(final_response) - # Assert + # Assert - is_final is based on end_of_session, not end_of_stream assert is_final is True assert result["response"] == "" - assert result["end_of_stream"] is True + assert result["end_of_session"] is True class TestDocumentRagResponseTranslator: diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index 480f2ee1..b620df7e 100644 --- a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -118,8 +118,8 @@ class TestCassandraQueryProcessor: # Verify result contains the queried triple assert len(result) == 1 - assert result[0].s.value == 'test_subject' - assert result[0].p.value == 'test_predicate' + assert result[0].s.iri == 'test_subject' + assert result[0].p.iri == 'test_predicate' assert result[0].o.value == 'test_object' def test_processor_initialization_with_defaults(self): @@ -182,8 +182,8 @@ class TestCassandraQueryProcessor: mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) assert len(result) == 1 - assert result[0].s.value == 'test_subject' - assert result[0].p.value == 'test_predicate' + assert result[0].s.iri == 'test_subject' + assert result[0].p.iri == 'test_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio @@ -219,8 +219,8 @@ class TestCassandraQueryProcessor: mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) assert len(result) == 1 - assert result[0].s.value == 'test_subject' - assert result[0].p.value == 'result_predicate' + assert result[0].s.iri == 'test_subject' + assert result[0].p.iri == 'result_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio @@ -256,8 +256,8 @@ class TestCassandraQueryProcessor: 
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) assert len(result) == 1 - assert result[0].s.value == 'result_subject' - assert result[0].p.value == 'test_predicate' + assert result[0].s.iri == 'result_subject' + assert result[0].p.iri == 'test_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio @@ -293,8 +293,8 @@ class TestCassandraQueryProcessor: mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) assert len(result) == 1 - assert result[0].s.value == 'result_subject' - assert result[0].p.value == 'result_predicate' + assert result[0].s.iri == 'result_subject' + assert result[0].p.iri == 'result_predicate' assert result[0].o.value == 'test_object' @pytest.mark.asyncio @@ -331,8 +331,8 @@ class TestCassandraQueryProcessor: mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) assert len(result) == 1 - assert result[0].s.value == 'all_subject' - assert result[0].p.value == 'all_predicate' + assert result[0].s.iri == 'all_subject' + assert result[0].p.iri == 'all_predicate' assert result[0].o.value == 'all_object' def test_add_args_method(self): @@ -637,8 +637,8 @@ class TestCassandraQueryPerformanceOptimizations: ) assert len(result) == 1 - assert result[0].s.value == 'result_subject' - assert result[0].p.value == 'test_predicate' + assert result[0].s.iri == 'result_subject' + assert result[0].p.iri == 'test_predicate' assert result[0].o.value == 'test_object' @pytest.mark.asyncio @@ -678,8 +678,8 @@ class TestCassandraQueryPerformanceOptimizations: ) assert len(result) == 1 - assert result[0].s.value == 'test_subject' - assert result[0].p.value == 'result_predicate' + assert result[0].s.iri == 'test_subject' + assert result[0].p.iri == 'result_predicate' assert result[0].o.value == 'test_object' @pytest.mark.asyncio @@ -802,7 +802,7 @@ class TestCassandraQueryPerformanceOptimizations: # Verify all results were 
returned assert len(result) == 5 for i, triple in enumerate(result): - assert triple.s.value == f'subject_{i}' # Mock returns literal values + assert triple.s.iri == f'subject_{i}' # Mock returns literal values assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' assert triple.p.type == IRI assert triple.o.iri == 'http://example.com/Person' # URIs use .iri diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index eddc1e12..af0cbfc2 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -540,41 +540,68 @@ class TestQuery: query.maybe_label = AsyncMock(side_effect=mock_maybe_label) # Call get_labelgraph - result = await query.get_labelgraph("test query") - + labeled_edges, uri_map = await query.get_labelgraph("test query") + # Verify get_subgraph was called query.get_subgraph.assert_called_once_with("test query") - + # Verify label triples are filtered out - assert len(result) == 2 # Label triple should be excluded - + assert len(labeled_edges) == 2 # Label triple should be excluded + # Verify maybe_label was called for non-label triples expected_calls = [ (("entity1",), {}), (("predicate1",), {}), (("object1",), {}), (("entity3",), {}), (("predicate3",), {}), (("object3",), {}) ] assert query.maybe_label.call_count == 6 - + # Verify result contains human-readable labels - expected_result = [ + expected_edges = [ ("Human Entity One", "Human Predicate One", "Human Object One"), ("Human Entity Three", "Human Predicate Three", "Human Object Three") ] - assert result == expected_result + assert labeled_edges == expected_edges + + # Verify uri_map maps labeled edges back to original URIs + assert len(uri_map) == 2 @pytest.mark.asyncio async def test_graph_rag_query_method(self): - """Test GraphRag.query method orchestrates full RAG pipeline""" + """Test GraphRag.query method orchestrates full RAG pipeline with real-time provenance""" + import json 
+ from trustgraph.retrieval.graph_rag.graph_rag import edge_id + # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_graph_embeddings_client = AsyncMock() mock_triples_client = AsyncMock() - - # Mock prompt client response + + # Mock prompt client responses for two-step process expected_response = "This is the RAG response" - mock_prompt_client.kg_prompt.return_value = expected_response - + test_labelgraph = [("Subject", "Predicate", "Object")] + + # Compute the edge ID for the test edge + test_edge_id = edge_id("Subject", "Predicate", "Object") + + # Create uri_map for the test edge (maps labeled edge ID to original URIs) + test_uri_map = { + test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object") + } + + # Mock edge selection response (JSONL format) + edge_selection_response = json.dumps({"id": test_edge_id, "reasoning": "relevant"}) + + # Configure prompt mock to return different responses based on prompt name + async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): + if prompt_name == "kg-edge-selection": + return edge_selection_response + elif prompt_name == "kg-synthesis": + return expected_response + return "" + + mock_prompt_client.prompt = mock_prompt + # Initialize GraphRag graph_rag = GraphRag( prompt_client=mock_prompt_client, @@ -583,39 +610,55 @@ class TestQuery: triples_client=mock_triples_client, verbose=False ) - - # Mock the Query class behavior by patching get_labelgraph - test_labelgraph = [("Subject", "Predicate", "Object")] - + # We need to patch the Query class's get_labelgraph method original_query_init = Query.__init__ original_get_labelgraph = Query.get_labelgraph - + def mock_query_init(self, *args, **kwargs): original_query_init(self, *args, **kwargs) - + async def mock_get_labelgraph(self, query_text): - return test_labelgraph - + return test_labelgraph, test_uri_map + Query.__init__ = mock_query_init 
Query.get_labelgraph = mock_get_labelgraph - + + # Collect provenance emitted via callback + provenance_events = [] + + async def collect_provenance(triples, prov_id): + provenance_events.append((triples, prov_id)) + try: - # Call GraphRag.query - result = await graph_rag.query( + # Call GraphRag.query with provenance callback + response = await graph_rag.query( query="test query", user="test_user", collection="test_collection", entity_limit=25, - triple_limit=15 + triple_limit=15, + explain_callback=collect_provenance ) - - # Verify prompt client was called with knowledge graph and query - mock_prompt_client.kg_prompt.assert_called_once_with("test query", test_labelgraph) - - # Verify result - assert result == expected_response - + + # Verify response text + assert response == expected_response + + # Verify provenance was emitted incrementally (4 events: session, retrieval, selection, answer) + assert len(provenance_events) == 4 + + # Verify each event has triples and a URN + for triples, prov_id in provenance_events: + assert isinstance(triples, list) + assert len(triples) > 0 + assert prov_id.startswith("urn:trustgraph:") + + # Verify order: session, retrieval, selection, answer + assert "session" in provenance_events[0][1] + assert "retrieval" in provenance_events[1][1] + assert "selection" in provenance_events[2][1] + assert "answer" in provenance_events[3][1] + finally: # Restore original methods Query.__init__ = original_query_init diff --git a/tests/unit/test_retrieval/test_graph_rag_service.py b/tests/unit/test_retrieval/test_graph_rag_service.py index ddfdfa75..2cd62286 100644 --- a/tests/unit/test_retrieval/test_graph_rag_service.py +++ b/tests/unit/test_retrieval/test_graph_rag_service.py @@ -1,6 +1,7 @@ """ -Unit tests for GraphRAG service non-streaming mode. -Tests that end_of_stream flag is correctly set in non-streaming responses. +Unit tests for GraphRAG service message format. 
+Tests the new message protocol with message_type, explain_id, and end_of_session. +Real-time explainability emission via callback. """ import pytest @@ -11,16 +12,14 @@ from trustgraph.schema import GraphRagQuery, GraphRagResponse class TestGraphRagService: - """Test GraphRAG service non-streaming behavior""" + """Test GraphRAG service message protocol""" @patch('trustgraph.retrieval.graph_rag.rag.GraphRag') @pytest.mark.asyncio - async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_graph_rag_class): + async def test_non_streaming_sends_chunk_then_provenance_messages(self, mock_graph_rag_class): """ - Test that non-streaming mode sets end_of_stream=True in response. - - This is a regression test for the bug where non-streaming responses - didn't set end_of_stream, causing clients to hang waiting for more data. + Test that non-streaming mode sends real-time provenance messages + followed by chunk message with response. """ # Setup processor processor = Processor( @@ -32,10 +31,22 @@ class TestGraphRagService: max_path_length=2 ) - # Setup mock GraphRag instance + # Setup mock GraphRag instance that calls explain_callback mock_rag_instance = AsyncMock() mock_graph_rag_class.return_value = mock_rag_instance - mock_rag_instance.query.return_value = "A small domesticated mammal." + + # Mock query() to call the explain_callback with each provenance event + async def mock_query(**kwargs): + explain_callback = kwargs.get('explain_callback') + if explain_callback: + # Simulate real-time provenance emission + await explain_callback([], "urn:trustgraph:session:test") + await explain_callback([], "urn:trustgraph:prov:retrieval:test") + await explain_callback([], "urn:trustgraph:prov:selection:test") + await explain_callback([], "urn:trustgraph:prov:answer:test") + return "A small domesticated mammal." 
+ + mock_rag_instance.query.side_effect = mock_query # Setup message with non-streaming request msg = MagicMock() @@ -47,7 +58,7 @@ class TestGraphRagService: triple_limit=30, max_subgraph_size=150, max_path_length=2, - streaming=False # Non-streaming mode + streaming=False ) msg.properties.return_value = {"id": "test-id"} @@ -55,30 +66,48 @@ class TestGraphRagService: consumer = MagicMock() flow = MagicMock() - # Mock flow to return AsyncMock for clients and response producer - mock_producer = AsyncMock() + mock_response_producer = AsyncMock() + mock_provenance_producer = AsyncMock() def flow_router(service_name): if service_name == "response": - return mock_producer - return AsyncMock() # embeddings, graph-embeddings, triples, prompt clients + return mock_response_producer + elif service_name == "explainability": + return mock_provenance_producer + return AsyncMock() flow.side_effect = flow_router # Execute await processor.on_request(msg, consumer, flow) - # Verify: response was sent with end_of_stream=True - mock_producer.send.assert_called_once() - sent_response = mock_producer.send.call_args[0][0] - assert isinstance(sent_response, GraphRagResponse) - assert sent_response.response == "A small domesticated mammal." - assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True" - assert sent_response.error is None + # Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session) + assert mock_response_producer.send.call_count == 6 + + # First 4 messages are explain (emitted in real-time during query) + for i in range(4): + prov_msg = mock_response_producer.send.call_args_list[i][0][0] + assert prov_msg.message_type == "explain" + assert prov_msg.explain_id is not None + + # 5th message is chunk with response + chunk_msg = mock_response_producer.send.call_args_list[4][0][0] + assert chunk_msg.message_type == "chunk" + assert chunk_msg.response == "A small domesticated mammal." 
+ assert chunk_msg.end_of_stream is True + + # 6th message is empty chunk with end_of_session=True + close_msg = mock_response_producer.send.call_args_list[5][0][0] + assert close_msg.message_type == "chunk" + assert close_msg.response == "" + assert close_msg.end_of_session is True + + # Verify provenance triples were sent to provenance queue + assert mock_provenance_producer.send.call_count == 4 @patch('trustgraph.retrieval.graph_rag.rag.GraphRag') @pytest.mark.asyncio - async def test_error_response_in_non_streaming_mode(self, mock_graph_rag_class): + async def test_error_response_closes_session(self, mock_graph_rag_class): """ - Test that error responses in non-streaming mode set end_of_stream=True. + Test that error responses set end_of_session=True. """ # Setup processor processor = Processor( @@ -105,7 +134,7 @@ class TestGraphRagService: triple_limit=30, max_subgraph_size=150, max_path_length=2, - streaming=False # Non-streaming mode + streaming=False ) msg.properties.return_value = {"id": "test-id"} @@ -113,22 +142,93 @@ class TestGraphRagService: consumer = MagicMock() flow = MagicMock() - mock_producer = AsyncMock() + mock_response_producer = AsyncMock() + mock_provenance_producer = AsyncMock() def flow_router(service_name): if service_name == "response": - return mock_producer + return mock_response_producer + elif service_name == "explainability": + return mock_provenance_producer return AsyncMock() flow.side_effect = flow_router # Execute await processor.on_request(msg, consumer, flow) - # Verify: error response was sent without end_of_stream (not streaming mode) - mock_producer.send.assert_called_once() - sent_response = mock_producer.send.call_args[0][0] + # Verify: error response was sent with session closed + mock_response_producer.send.assert_called_once() + sent_response = mock_response_producer.send.call_args[0][0] assert isinstance(sent_response, GraphRagResponse) - assert sent_response.response is None + assert sent_response.message_type == 
"chunk" assert sent_response.error is not None assert sent_response.error.message == "Test error" - # Note: error responses in non-streaming mode don't set end_of_stream - # because streaming was never started + assert sent_response.end_of_stream is True + assert sent_response.end_of_session is True + + @patch('trustgraph.retrieval.graph_rag.rag.GraphRag') + @pytest.mark.asyncio + async def test_no_provenance_sends_empty_chunk_to_close(self, mock_graph_rag_class): + """ + Test that when no provenance callback is invoked, an empty chunk closes the session. + """ + # Setup processor + processor = Processor( + taskgroup=MagicMock(), + id="test-processor", + entity_limit=50, + triple_limit=30, + max_subgraph_size=150, + max_path_length=2 + ) + + # Setup mock GraphRag instance that doesn't call provenance callback + mock_rag_instance = AsyncMock() + mock_graph_rag_class.return_value = mock_rag_instance + + async def mock_query(**kwargs): + # Don't call explain_callback + return "Response text" + + mock_rag_instance.query.side_effect = mock_query + + # Setup message + msg = MagicMock() + msg.value.return_value = GraphRagQuery( + query="Test query", + user="trustgraph", + collection="default", + streaming=False + ) + msg.properties.return_value = {"id": "test-id"} + + # Setup flow mock + consumer = MagicMock() + flow = MagicMock() + + mock_response_producer = AsyncMock() + mock_provenance_producer = AsyncMock() + def flow_router(service_name): + if service_name == "response": + return mock_response_producer + elif service_name == "explainability": + return mock_provenance_producer + return AsyncMock() + flow.side_effect = flow_router + + # Execute + await processor.on_request(msg, consumer, flow) + + # Verify: 2 messages (chunk + empty chunk to close) + assert mock_response_producer.send.call_count == 2 + + # First is the response chunk + chunk_msg = mock_response_producer.send.call_args_list[0][0][0] + assert chunk_msg.message_type == "chunk" + assert chunk_msg.response 
== "Response text" + assert chunk_msg.end_of_stream is True + + # Second is empty chunk to close session + close_msg = mock_response_producer.send.call_args_list[1][0][0] + assert close_msg.message_type == "chunk" + assert close_msg.response == "" + assert close_msg.end_of_session is True diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 99938d5b..3279609b 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -110,15 +110,25 @@ class AsyncSocketClient: # Parse different chunk types chunk = self._parse_chunk(resp) - yield chunk + if chunk is not None: # Skip provenance messages in streaming + yield chunk - # Check if this is the final chunk - if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"): + # Check if this is the final message + # end_of_session indicates entire session is complete (including provenance) + # end_of_dialog is for agent dialogs + # complete is from the gateway envelope + if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"): break def _parse_chunk(self, resp: Dict[str, Any]): - """Parse response chunk into appropriate type""" + """Parse response chunk into appropriate type. 
Returns None for non-content messages.""" chunk_type = resp.get("chunk_type") + message_type = resp.get("message_type") + + # Handle new GraphRAG message format with message_type + if message_type == "provenance": + # Provenance messages are not yielded to user - they're metadata + return None if chunk_type == "thought": return AgentThought( @@ -143,7 +153,7 @@ class AsyncSocketClient: end_of_message=resp.get("end_of_message", False) ) else: - # RAG-style chunk (or generic chunk) + # RAG-style chunk (or generic chunk with message_type="chunk") # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field content = resp.get("response", resp.get("chunk", resp.get("text", ""))) return RAGChunk( diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 700a4531..a08b8bca 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -11,7 +11,7 @@ import websockets from typing import Optional, Dict, Any, Iterator, Union, List from threading import Lock -from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk +from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent from . 
exceptions import ProtocolException, raise_from_error_dict @@ -310,15 +310,28 @@ class SocketClient: # Parse different chunk types chunk = self._parse_chunk(resp) - yield chunk + if chunk is not None: # Skip provenance messages in streaming + yield chunk - # Check if this is the final chunk - if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"): + # Check if this is the final message + # end_of_session indicates entire session is complete (including provenance) + # end_of_dialog is for agent dialogs + # complete is from the gateway envelope + if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"): break - def _parse_chunk(self, resp: Dict[str, Any]) -> StreamingChunk: - """Parse response chunk into appropriate type""" + def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]: + """Parse response chunk into appropriate type. Returns None for non-content messages.""" chunk_type = resp.get("chunk_type") + message_type = resp.get("message_type") + + # Handle new GraphRAG message format with message_type + if message_type == "provenance": + if include_provenance: + # Return provenance event for explainability + return ProvenanceEvent(provenance_id=resp.get("provenance_id", "")) + # Provenance messages are not yielded to user - they're metadata + return None if chunk_type == "thought": return AgentThought( @@ -360,7 +373,7 @@ class SocketClient: end_of_dialog=resp.get("end_of_dialog", False) ) else: - # RAG-style chunk (or generic chunk) + # RAG-style chunk (or generic chunk with message_type="chunk") # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field content = resp.get("response", resp.get("chunk", resp.get("text", ""))) return RAGChunk( diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index 47aa5ae0..f66f7b82 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ 
b/trustgraph-base/trustgraph/api/types.py @@ -202,3 +202,29 @@ class RAGChunk(StreamingChunk): chunk_type: str = "rag" end_of_stream: bool = False error: Optional[Dict[str, str]] = None + +@dataclasses.dataclass +class ProvenanceEvent: + """ + Provenance event for explainability. + + Emitted during GraphRAG queries when explainable mode is enabled. + Each event represents a provenance node created during query processing. + + Attributes: + provenance_id: URI of the provenance node (e.g., urn:trustgraph:session:abc123) + event_type: Type of provenance event (session, retrieval, selection, answer) + """ + provenance_id: str + event_type: str = "" # Derived from provenance_id (session, retrieval, selection, answer) + + def __post_init__(self): + # Extract event type from provenance_id + if "session" in self.provenance_id: + self.event_type = "session" + elif "retrieval" in self.provenance_id: + self.event_type = "retrieval" + elif "selection" in self.provenance_id: + self.event_type = "selection" + elif "answer" in self.provenance_id: + self.event_type = "answer" diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index 22166bd9..85900089 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -90,13 +90,31 @@ class GraphRagResponseTranslator(MessageTranslator): def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]: result = {} - # Include response content (even if empty string) + # Include message_type + message_type = getattr(obj, "message_type", "") + if message_type: + result["message_type"] = message_type + + # Include response content for chunk messages if obj.response is not None: result["response"] = obj.response - # Include end_of_stream flag + # Include explain_id for explain messages + explain_id = getattr(obj, "explain_id", None) + if explain_id: + result["explain_id"] = 
explain_id + + # Include explain_collection for explain messages + explain_collection = getattr(obj, "explain_collection", None) + if explain_collection: + result["explain_collection"] = explain_collection + + # Include end_of_stream flag (LLM stream complete) result["end_of_stream"] = getattr(obj, "end_of_stream", False) + # Include end_of_session flag (entire session complete) + result["end_of_session"] = getattr(obj, "end_of_session", False) + # Always include error if present if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "type": obj.error.type} @@ -105,5 +123,6 @@ class GraphRagResponseTranslator(MessageTranslator): def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - is_final = getattr(obj, 'end_of_stream', False) + # Session is complete when end_of_session is True + is_final = getattr(obj, 'end_of_session', False) return self.from_pulsar(obj), is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index 3e80dad6..bca4c156 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -40,6 +40,11 @@ from . uris import ( activity_uri, statement_uri, agent_uri, + # Query-time provenance URIs + query_session_uri, + retrieval_uri, + selection_uri, + answer_uri, ) # Namespace constants @@ -58,6 +63,8 @@ from . namespaces import ( TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION, TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL, TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH, + # Query-time provenance predicates + TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_CONTENT, ) # Triple builders @@ -65,6 +72,11 @@ from . 
triples import ( document_triples, derived_entity_triples, triple_provenance_triples, + # Query-time provenance triple builders + query_session_triples, + retrieval_triples, + selection_triples, + answer_triples, ) # Vocabulary bootstrap @@ -86,6 +98,11 @@ __all__ = [ "activity_uri", "statement_uri", "agent_uri", + # Query-time provenance URIs + "query_session_uri", + "retrieval_uri", + "selection_uri", + "answer_uri", # Namespaces "PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT", "PROV_WAS_DERIVED_FROM", "PROV_WAS_GENERATED_BY", @@ -97,10 +114,17 @@ __all__ = [ "TG_CHUNK_SIZE", "TG_CHUNK_OVERLAP", "TG_COMPONENT_VERSION", "TG_LLM_MODEL", "TG_ONTOLOGY", "TG_EMBEDDING_MODEL", "TG_SOURCE_TEXT", "TG_SOURCE_CHAR_OFFSET", "TG_SOURCE_CHAR_LENGTH", + # Query-time provenance predicates + "TG_QUERY", "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_CONTENT", # Triple builders "document_triples", "derived_entity_triples", "triple_provenance_triples", + # Query-time provenance triple builders + "query_session_triples", + "retrieval_triples", + "selection_triples", + "answer_triples", # Vocabulary "get_vocabulary_triples", "PROV_CLASS_LABELS", diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index b207b38f..f348556e 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -58,3 +58,12 @@ TG_EMBEDDING_MODEL = TG + "embeddingModel" TG_SOURCE_TEXT = TG + "sourceText" TG_SOURCE_CHAR_OFFSET = TG + "sourceCharOffset" TG_SOURCE_CHAR_LENGTH = TG + "sourceCharLength" + +# Query-time provenance predicates +TG_QUERY = TG + "query" +TG_EDGE_COUNT = TG + "edgeCount" +TG_SELECTED_EDGE = TG + "selectedEdge" +TG_EDGE = TG + "edge" +TG_REASONING = TG + "reasoning" +TG_CONTENT = TG + "content" +TG_DOCUMENT = TG + "document" # Reference to document in librarian diff --git a/trustgraph-base/trustgraph/provenance/triples.py 
b/trustgraph-base/trustgraph/provenance/triples.py index cbb0e420..3e6abae8 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -17,9 +17,12 @@ from . namespaces import ( TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH, TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION, TG_LLM_MODEL, TG_ONTOLOGY, TG_REIFIES, + # Query-time provenance predicates + TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_CONTENT, + TG_DOCUMENT, ) -from . uris import activity_uri, agent_uri +from . uris import activity_uri, agent_uri, edge_selection_uri def _iri(uri: str) -> Term: @@ -252,3 +255,177 @@ def triple_provenance_triples( triples.append(_triple(act_uri, TG_ONTOLOGY, _iri(ontology_uri))) return triples + + +# Query-time provenance triple builders + +def query_session_triples( + session_uri: str, + query: str, + timestamp: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for a query session activity. + + Creates: + - Activity declaration for the query session + - Query text and timestamp + + Args: + session_uri: URI of the session (from query_session_uri) + query: The user's query text + timestamp: ISO timestamp (defaults to now) + + Returns: + List of Triple objects + """ + if timestamp is None: + timestamp = datetime.utcnow().isoformat() + "Z" + + return [ + _triple(session_uri, RDF_TYPE, _iri(PROV_ACTIVITY)), + _triple(session_uri, RDFS_LABEL, _literal("GraphRAG query session")), + _triple(session_uri, PROV_STARTED_AT_TIME, _literal(timestamp)), + _triple(session_uri, TG_QUERY, _literal(query)), + ] + + +def retrieval_triples( + retrieval_uri: str, + session_uri: str, + edge_count: int, +) -> List[Triple]: + """ + Build triples for a retrieval entity (all edges retrieved from subgraph). 
+ + Creates: + - Entity declaration for retrieval + - wasGeneratedBy link to session + - Edge count metadata + + Args: + retrieval_uri: URI of the retrieval entity (from retrieval_uri) + session_uri: URI of the parent session + edge_count: Number of edges retrieved + + Returns: + List of Triple objects + """ + return [ + _triple(retrieval_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(retrieval_uri, RDFS_LABEL, _literal("Retrieved edges")), + _triple(retrieval_uri, PROV_WAS_GENERATED_BY, _iri(session_uri)), + _triple(retrieval_uri, TG_EDGE_COUNT, _literal(edge_count)), + ] + + +def _quoted_triple(s: str, p: str, o: str) -> Term: + """Create a quoted triple term (RDF-star) from string values.""" + return Term( + type=TRIPLE, + triple=Triple(s=_iri(s), p=_iri(p), o=_iri(o)) + ) + + +def selection_triples( + selection_uri: str, + retrieval_uri: str, + selected_edges_with_reasoning: List[dict], + session_id: str = "", +) -> List[Triple]: + """ + Build triples for a selection entity (selected edges with reasoning). + + Creates: + - Entity declaration for selection + - wasDerivedFrom link to retrieval + - For each selected edge: an edge selection entity with quoted triple and reasoning + + Structure: + tg:selectedEdge . + tg:edge <<

>> . + tg:reasoning "reason" . + + Args: + selection_uri: URI of the selection entity (from selection_uri) + retrieval_uri: URI of the parent retrieval entity + selected_edges_with_reasoning: List of dicts with 'edge' (s,p,o tuple) and 'reasoning' + session_id: Session UUID for generating edge selection URIs + + Returns: + List of Triple objects + """ + triples = [ + _triple(selection_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(selection_uri, RDFS_LABEL, _literal("Selected edges")), + _triple(selection_uri, PROV_WAS_DERIVED_FROM, _iri(retrieval_uri)), + ] + + # Add each selected edge with its reasoning via intermediate entity + for idx, edge_info in enumerate(selected_edges_with_reasoning): + edge = edge_info.get("edge") + reasoning = edge_info.get("reasoning", "") + + if edge: + s, p, o = edge + + # Create intermediate entity for this edge selection + edge_sel_uri = edge_selection_uri(session_id, idx) + + # Link selection to edge selection entity + triples.append( + _triple(selection_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri)) + ) + + # Attach quoted triple to edge selection entity + quoted = _quoted_triple(s, p, o) + triples.append( + Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted) + ) + + # Attach reasoning to edge selection entity + if reasoning: + triples.append( + _triple(edge_sel_uri, TG_REASONING, _literal(reasoning)) + ) + + return triples + + +def answer_triples( + answer_uri: str, + selection_uri: str, + answer_text: str = "", + document_id: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for an answer entity (final synthesis text). 
+ + Creates: + - Entity declaration for answer + - wasDerivedFrom link to selection + - Either document reference (if document_id provided) or inline content + + Args: + answer_uri: URI of the answer entity (from answer_uri) + selection_uri: URI of the parent selection entity + answer_text: The synthesized answer text (used if no document_id) + document_id: Optional librarian document ID (preferred over inline content) + + Returns: + List of Triple objects + """ + triples = [ + _triple(answer_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(answer_uri, RDFS_LABEL, _literal("GraphRAG answer")), + _triple(answer_uri, PROV_WAS_DERIVED_FROM, _iri(selection_uri)), + ] + + if document_id: + # Store reference to document in librarian (as IRI) + triples.append(_triple(answer_uri, TG_DOCUMENT, _iri(document_id))) + elif answer_text: + # Fallback: store inline content + triples.append(_triple(answer_uri, TG_CONTENT, _literal(answer_text))) + + return triples diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py index 33b00bfd..0cd5baa4 100644 --- a/trustgraph-base/trustgraph/provenance/uris.py +++ b/trustgraph-base/trustgraph/provenance/uris.py @@ -60,3 +60,75 @@ def statement_uri(stmt_id: str = None) -> str: def agent_uri(component_name: str) -> str: """Generate URI for a TrustGraph component agent.""" return f"{TRUSTGRAPH_BASE}/agent/{_encode_id(component_name)}" + + +# Query-time provenance URIs +# These URIs use the urn:trustgraph: namespace to distinguish query-time +# provenance from extraction-time provenance (which uses https://trustgraph.ai/) + +def query_session_uri(session_id: str = None) -> str: + """ + Generate URI for a query session activity. + + Args: + session_id: Optional UUID string. Auto-generates if not provided. 
+ + Returns: + URN in format: urn:trustgraph:session:{uuid} + """ + if session_id is None: + session_id = str(uuid.uuid4()) + return f"urn:trustgraph:session:{session_id}" + + +def retrieval_uri(session_id: str) -> str: + """ + Generate URI for a retrieval entity (edges retrieved from subgraph). + + Args: + session_id: The session UUID (same as query_session_uri). + + Returns: + URN in format: urn:trustgraph:prov:retrieval:{uuid} + """ + return f"urn:trustgraph:prov:retrieval:{session_id}" + + +def selection_uri(session_id: str) -> str: + """ + Generate URI for a selection entity (selected edges with reasoning). + + Args: + session_id: The session UUID (same as query_session_uri). + + Returns: + URN in format: urn:trustgraph:prov:selection:{uuid} + """ + return f"urn:trustgraph:prov:selection:{session_id}" + + +def answer_uri(session_id: str) -> str: + """ + Generate URI for an answer entity (final synthesis text). + + Args: + session_id: The session UUID (same as query_session_uri). + + Returns: + URN in format: urn:trustgraph:prov:answer:{uuid} + """ + return f"urn:trustgraph:prov:answer:{session_id}" + + +def edge_selection_uri(session_id: str, edge_index: int) -> str: + """ + Generate URI for an edge selection item (links edge to reasoning). + + Args: + session_id: The session UUID. + edge_index: Index of this edge in the selection (0-based). 
+ + Returns: + URN in format: urn:trustgraph:prov:edge:{uuid}:{index} + """ + return f"urn:trustgraph:prov:edge:{session_id}:{edge_index}" diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index 4337cb9b..8f222d98 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -21,7 +21,11 @@ class GraphRagQuery: class GraphRagResponse: error: Error | None = None response: str = "" - end_of_stream: bool = False + end_of_stream: bool = False # LLM response stream complete + explain_id: str | None = None # Single explain URI (announced as created) + explain_collection: str | None = None # Collection where explain was stored + message_type: str = "" # "chunk" or "explain" + end_of_session: bool = False # Entire session complete ############################################################################ diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 5fa359ab..03c4dfdf 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -3,7 +3,11 @@ Uses the GraphRAG service to answer a question """ import argparse +import json import os +import sys +import websockets +import asyncio from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') @@ -15,11 +19,609 @@ default_triple_limit = 30 default_max_subgraph_size = 150 default_max_path_length = 2 +# Provenance predicates +TG = "https://trustgraph.ai/ns/" +TG_QUERY = TG + "query" +TG_EDGE_COUNT = TG + "edgeCount" +TG_SELECTED_EDGE = TG + "selectedEdge" +TG_EDGE = TG + "edge" +TG_REASONING = TG + "reasoning" +TG_CONTENT = TG + "content" +TG_REIFIES = TG + "reifies" +PROV = "http://www.w3.org/ns/prov#" +PROV_STARTED_AT_TIME = PROV + "startedAtTime" +PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" +RDFS_LABEL = 
"http://www.w3.org/2000/01/rdf-schema#label" + + +def _get_event_type(prov_id): + """Extract event type from provenance_id""" + if "session" in prov_id: + return "session" + elif "retrieval" in prov_id: + return "retrieval" + elif "selection" in prov_id: + return "selection" + elif "answer" in prov_id: + return "answer" + return "provenance" + + +def _format_provenance_details(event_type, triples): + """Format provenance details based on event type and triples""" + lines = [] + + if event_type == "session": + # Show query and timestamp + for s, p, o in triples: + if p == TG_QUERY: + lines.append(f" Query: {o}") + elif p == PROV_STARTED_AT_TIME: + lines.append(f" Time: {o}") + + elif event_type == "retrieval": + # Show edge count + for s, p, o in triples: + if p == TG_EDGE_COUNT: + lines.append(f" Edges retrieved: {o}") + + elif event_type == "selection": + # For selection, just count edge selection URIs + # The actual edge details are fetched separately via edge_selections parameter + edge_sel_uris = [] + for s, p, o in triples: + if p == TG_SELECTED_EDGE: + edge_sel_uris.append(o) + if edge_sel_uris: + lines.append(f" Selected {len(edge_sel_uris)} edge(s)") + + elif event_type == "answer": + # Show content length (not full content - it's already streamed) + for s, p, o in triples: + if p == TG_CONTENT: + lines.append(f" Answer length: {len(o)} chars") + + return lines + + +async def _query_triples_once(ws_url, flow_id, prov_id, user, collection, debug=False): + """Query triples for a provenance node (single attempt)""" + request = { + "id": "triples-request", + "service": "triples", + "flow": flow_id, + "request": { + "s": {"t": "i", "i": prov_id}, + "user": user, + "collection": collection, + "limit": 100 + } + } + + if debug: + print(f" [debug] querying triples for s={prov_id}", file=sys.stderr) + + triples = [] + try: + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: + await websocket.send(json.dumps(request)) + + async 
for raw_message in websocket: + response = json.loads(raw_message) + + if debug: + print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr) + + if response.get("id") != "triples-request": + continue + + if "error" in response: + if debug: + print(f" [debug] error: {response['error']}", file=sys.stderr) + break + + if "response" in response: + resp = response["response"] + # Handle triples response + # Response format: {"response": [triples...]} + # Each triple uses compact keys: "i" for iri, "v" for value, "t" for type + triple_list = resp.get("response", []) + for t in triple_list: + s = t.get("s", {}).get("i", t.get("s", {}).get("v", "")) + p = t.get("p", {}).get("i", t.get("p", {}).get("v", "")) + # Handle quoted triples (type "t") and regular values + o_term = t.get("o", {}) + if o_term.get("t") == "t": + # Quoted triple - extract s, p, o from nested structure + tr = o_term.get("tr", {}) + o = { + "s": tr.get("s", {}).get("i", ""), + "p": tr.get("p", {}).get("i", ""), + "o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")), + } + else: + o = o_term.get("i", o_term.get("v", "")) + triples.append((s, p, o)) + + if resp.get("complete") or response.get("complete"): + break + except Exception as e: + if debug: + print(f" [debug] exception: {e}", file=sys.stderr) + + if debug: + print(f" [debug] got {len(triples)} triples", file=sys.stderr) + + return triples + + +async def _query_triples(ws_url, flow_id, prov_id, user, collection, max_retries=5, retry_delay=0.2, debug=False): + """Query triples for a provenance node with retries for race condition""" + for attempt in range(max_retries): + triples = await _query_triples_once(ws_url, flow_id, prov_id, user, collection, debug) + if triples: + return triples + # Wait before retry if empty (triples may not be stored yet) + if attempt < max_retries - 1: + if debug: + print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr) + await asyncio.sleep(retry_delay) + return [] + + +async 
def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, collection, debug=False): + """ + Query for provenance of an edge (s, p, o) in the knowledge graph. + + Finds statements that reify the edge via tg:reifies, then follows + prov:wasDerivedFrom to find source documents. + + Returns list of source URIs (chunks, pages, documents). + """ + # Query for statements that reify this edge: ?stmt tg:reifies <> + request = { + "id": "edge-prov-request", + "service": "triples", + "flow": flow_id, + "request": { + "p": {"t": "i", "i": TG_REIFIES}, + "o": { + "t": "t", # Quoted triple type + "tr": { + "s": {"t": "i", "i": edge_s}, + "p": {"t": "i", "i": edge_p}, + "o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o}, + } + }, + "user": user, + "collection": collection, + "limit": 10 + } + } + + if debug: + print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr) + + stmt_uris = [] + try: + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != "edge-prov-request": + continue + + if "error" in response: + if debug: + print(f" [debug] error: {response['error']}", file=sys.stderr) + break + + if "response" in response: + resp = response["response"] + triple_list = resp.get("response", []) + for t in triple_list: + s = t.get("s", {}).get("i", "") + if s: + stmt_uris.append(s) + + if resp.get("complete") or response.get("complete"): + break + except Exception as e: + if debug: + print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr) + + if debug: + print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr) + + # For each statement, query wasDerivedFrom to find sources + sources = [] + for stmt_uri in stmt_uris: + # Query: stmt_uri 
prov:wasDerivedFrom ?source + request = { + "id": "derived-from-request", + "service": "triples", + "flow": flow_id, + "request": { + "s": {"t": "i", "i": stmt_uri}, + "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, + "user": user, + "collection": collection, + "limit": 10 + } + } + + try: + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != "derived-from-request": + continue + + if "error" in response: + break + + if "response" in response: + resp = response["response"] + triple_list = resp.get("response", []) + for t in triple_list: + o = t.get("o", {}).get("i", "") + if o: + sources.append(o) + + if resp.get("complete") or response.get("complete"): + break + except Exception as e: + if debug: + print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr) + + if debug: + print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr) + + return sources + + +async def _query_derived_from(ws_url, flow_id, uri, user, collection, debug=False): + """Query for the prov:wasDerivedFrom parent of a URI. 
Returns None if no parent.""" + request = { + "id": "parent-request", + "service": "triples", + "flow": flow_id, + "request": { + "s": {"t": "i", "i": uri}, + "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, + "user": user, + "collection": collection, + "limit": 1 + } + } + + try: + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != "parent-request": + continue + + if "error" in response: + break + + if "response" in response: + resp = response["response"] + triple_list = resp.get("response", []) + if triple_list: + return triple_list[0].get("o", {}).get("i", None) + + if resp.get("complete") or response.get("complete"): + break + except Exception as e: + if debug: + print(f" [debug] exception querying parent: {e}", file=sys.stderr) + + return None + + +async def _trace_provenance_chain(ws_url, flow_id, source_uri, user, collection, label_cache, debug=False): + """ + Trace the full provenance chain from a source URI up to the root document. + Returns a list of (uri, label) tuples from leaf to root. + """ + chain = [] + current = source_uri + max_depth = 10 # Prevent infinite loops + + for _ in range(max_depth): + if not current: + break + + # Get label for current entity + label = await _query_label(ws_url, flow_id, current, user, collection, label_cache, debug) + chain.append((current, label)) + + # Get parent + parent = await _query_derived_from(ws_url, flow_id, current, user, collection, debug) + if not parent or parent == current: + break + current = parent + + return chain + + +def _format_provenance_chain(chain): + """ + Format a provenance chain as a human-readable string. + Chain is [(uri, label), ...] from leaf to root. 
+ """ + if not chain: + return "" + + # Show labels, from leaf to root + labels = [label for uri, label in chain] + return " → ".join(labels) + + +def _is_iri(value): + """Check if a value looks like an IRI.""" + if not isinstance(value, str): + return False + return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:") + + +async def _query_label(ws_url, flow_id, iri, user, collection, label_cache, debug=False): + """ + Query for the rdfs:label of an IRI. + Uses label_cache to avoid repeated queries. + Returns the label if found, otherwise returns the IRI. + """ + if not _is_iri(iri): + return iri + + # Check cache first + if iri in label_cache: + return label_cache[iri] + + request = { + "id": "label-request", + "service": "triples", + "flow": flow_id, + "request": { + "s": {"t": "i", "i": iri}, + "p": {"t": "i", "i": RDFS_LABEL}, + "user": user, + "collection": collection, + "limit": 1 + } + } + + label = iri # Default to IRI if no label found + try: + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != "label-request": + continue + + if "error" in response: + break + + if "response" in response: + resp = response["response"] + triple_list = resp.get("response", []) + if triple_list: + # Get the label value + o = triple_list[0].get("o", {}) + label = o.get("v", o.get("i", iri)) + + if resp.get("complete") or response.get("complete"): + break + except Exception as e: + if debug: + print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr) + + # Cache the result + label_cache[iri] = label + return label + + +async def _resolve_edge_labels(ws_url, flow_id, edge_triple, user, collection, label_cache, debug=False): + """ + Resolve labels for all IRI components of an edge triple. + Returns (s_label, p_label, o_label). 
+ """ + s = edge_triple.get("s", "?") + p = edge_triple.get("p", "?") + o = edge_triple.get("o", "?") + + s_label = await _query_label(ws_url, flow_id, s, user, collection, label_cache, debug) + p_label = await _query_label(ws_url, flow_id, p, user, collection, label_cache, debug) + o_label = await _query_label(ws_url, flow_id, o, user, collection, label_cache, debug) + + return s_label, p_label, o_label + + +async def _question_explainable( + url, flow_id, question, user, collection, entity_limit, triple_limit, + max_subgraph_size, max_path_length, token=None, debug=False +): + """Execute graph RAG with explainability - shows provenance events with details""" + # Convert HTTP URL to WebSocket URL + if url.startswith("http://"): + ws_url = url.replace("http://", "ws://", 1) + elif url.startswith("https://"): + ws_url = url.replace("https://", "wss://", 1) + else: + ws_url = f"ws://{url}" + + ws_url = f"{ws_url.rstrip('/')}/api/v1/socket" + if token: + ws_url = f"{ws_url}?token={token}" + + # Cache for label lookups to avoid repeated queries + label_cache = {} + + request = { + "id": "cli-request", + "service": "graph-rag", + "flow": flow_id, + "request": { + "query": question, + "user": user, + "collection": collection, + "entity-limit": entity_limit, + "triple-limit": triple_limit, + "max-subgraph-size": max_subgraph_size, + "max-path-length": max_path_length, + "streaming": True + } + } + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != "cli-request": + continue + + if "error" in response: + print(f"\nError: {response['error']}", file=sys.stderr) + break + + if "response" in response: + resp = response["response"] + + # Check for errors in response + if "error" in resp and resp["error"]: + err = resp["error"] + print(f"\nError: {err.get('message', 'Unknown error')}", 
file=sys.stderr) + break + + message_type = resp.get("message_type", "") + + if debug: + print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr) + + if message_type == "explain": + # Display explain event with details + explain_id = resp.get("explain_id", "") + explain_collection = resp.get("explain_collection", "explainability") + if explain_id: + event_type = _get_event_type(explain_id) + print(f"\n [{event_type}] {explain_id}", file=sys.stderr) + + # Query triples for this explain node (using explain collection from event) + triples = await _query_triples( + ws_url, flow_id, explain_id, user, explain_collection, debug=debug + ) + + # Format and display details + details = _format_provenance_details(event_type, triples) + for line in details: + print(line, file=sys.stderr) + + # For selection events, query each edge selection for details + if event_type == "selection": + for s, p, o in triples: + if debug: + print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr) + if p == TG_SELECTED_EDGE and isinstance(o, str): + if debug: + print(f" [debug] querying edge selection: {o}", file=sys.stderr) + # Query the edge selection entity (using explain collection from event) + edge_triples = await _query_triples( + ws_url, flow_id, o, user, explain_collection, debug=debug + ) + if debug: + print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr) + # Extract edge and reasoning + edge_triple = None # Store the actual triple for provenance lookup + reasoning = None + for es, ep, eo in edge_triples: + if debug: + print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr) + if ep == TG_EDGE and isinstance(eo, dict): + # eo is a quoted triple dict + edge_triple = eo + elif ep == TG_REASONING: + reasoning = eo + if edge_triple: + # Resolve labels for edge components + s_label, p_label, o_label = await _resolve_edge_labels( + ws_url, flow_id, edge_triple, user, collection, + label_cache, debug=debug 
+ ) + print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) + if reasoning: + r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning + print(f" Reason: {r_short}", file=sys.stderr) + + # Trace edge provenance in the user's collection (not explainability) + if edge_triple: + sources = await _query_edge_provenance( + ws_url, flow_id, + edge_triple.get("s", ""), + edge_triple.get("p", ""), + edge_triple.get("o", ""), + user, collection, # Use the query collection, not explainability + debug=debug + ) + if sources: + for src in sources: + # Trace full chain from source to root document + chain = await _trace_provenance_chain( + ws_url, flow_id, src, user, collection, + label_cache, debug=debug + ) + chain_str = _format_provenance_chain(chain) + print(f" Source: {chain_str}", file=sys.stderr) + + elif message_type == "chunk" or not message_type: + # Display response chunk + chunk = resp.get("response", "") + if chunk: + print(chunk, end="", flush=True) + + # Check if session is complete + if resp.get("end_of_session"): + break + + print() # Final newline + + def question( url, flow_id, question, user, collection, entity_limit, triple_limit, - max_subgraph_size, max_path_length, streaming=True, token=None + max_subgraph_size, max_path_length, streaming=True, token=None, + explainable=False, debug=False ): + # Explainable mode uses direct websocket to capture provenance events + if explainable: + asyncio.run(_question_explainable( + url=url, + flow_id=flow_id, + question=question, + user=user, + collection=collection, + entity_limit=entity_limit, + triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length, + token=token, + debug=debug + )) + return + # Create API client api = Api(url=url, token=token) @@ -138,6 +740,18 @@ def main(): help='Disable streaming (use non-streaming mode)' ) + parser.add_argument( + '-x', '--explainable', + action='store_true', + help='Show provenance events for 
explainability (implies streaming)' + ) + + parser.add_argument( + '--debug', + action='store_true', + help='Show debug output for troubleshooting' + ) + args = parser.parse_args() try: @@ -154,6 +768,8 @@ def main(): max_path_length=args.max_path_length, streaming=not args.no_streaming, token=args.token, + explainable=args.explainable, + debug=args.debug, ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index 9cea4f48..1bb88f21 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -13,7 +13,7 @@ from .... direct.cassandra_kg import ( EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH ) from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error -from .... schema import Term, Triple, IRI, LITERAL, TRIPLE +from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK from .... base import TriplesQueryService from .... 
base.cassandra_config import add_cassandra_args, resolve_cassandra_config @@ -23,6 +23,36 @@ logger = logging.getLogger(__name__) default_ident = "triples-query" +def serialize_triple(triple): + """Serialize a Triple object to JSON for querying (must match storage format).""" + if triple is None: + return None + + def term_to_dict(term): + if term is None: + return None + result = {"type": term.type} + if term.type == IRI: + result["iri"] = term.iri + elif term.type == LITERAL: + result["value"] = term.value + if term.datatype: + result["datatype"] = term.datatype + if term.language: + result["language"] = term.language + elif term.type == BLANK: + result["id"] = term.id + elif term.type == TRIPLE: + result["triple"] = serialize_triple(term.triple) + return result + + return json.dumps({ + "s": term_to_dict(triple.s), + "p": term_to_dict(triple.p), + "o": term_to_dict(triple.o), + }) + + def get_term_value(term): """Extract the string value from a Term""" if term is None: @@ -31,6 +61,9 @@ def get_term_value(term): return term.iri elif term.type == LITERAL: return term.value + elif term.type == TRIPLE: + # Serialize nested triple to JSON (must match storage format) + return serialize_triple(term.triple) else: # For blank nodes or other types, use id or value return term.id or term.value @@ -66,51 +99,50 @@ def deserialize_term(term_dict): return Term(type=LITERAL, value=str(term_dict)) -def create_term(value, otype=None, dtype=None, lang=None): +def create_term(value, term_type=None, datatype=None, language=None): """ Create a Term from a string value, optionally using type metadata. Args: value: The string value - otype: Object type - 'u' (URI), 'l' (literal), 't' (triple) - dtype: XSD datatype (for literals) - lang: Language tag (for literals) + term_type: 'u' (IRI), 'l' (literal), 't' (triple) + datatype: XSD datatype for literals + language: Language tag for literals - If otype is provided, uses it to determine Term type. 
- Otherwise falls back to URL detection heuristic. + If term_type is provided, uses it to determine Term type. + Otherwise falls back to URL detection heuristic for object values. """ - if otype is not None: - if otype == 'u': - return Term(type=IRI, iri=value) - elif otype == 'l': - return Term( - type=LITERAL, - value=value, - datatype=dtype or "", - language=lang or "" - ) - elif otype == 't': - # Triple/reification - parse JSON and create nested Triple - try: - triple_data = json.loads(value) if isinstance(value, str) else value - if isinstance(triple_data, dict): - return Term( - type=TRIPLE, - triple=Triple( - s=deserialize_term(triple_data.get("s")), - p=deserialize_term(triple_data.get("p")), - o=deserialize_term(triple_data.get("o")), - ) + if term_type == 'u': + return Term(type=IRI, iri=value) + elif term_type == 'l': + return Term( + type=LITERAL, + value=value, + datatype=datatype or "", + language=language or "" + ) + elif term_type == 't': + # Triple/reification - parse JSON and create nested Triple + try: + triple_data = json.loads(value) if isinstance(value, str) else value + if isinstance(triple_data, dict): + return Term( + type=TRIPLE, + triple=Triple( + s=deserialize_term(triple_data.get("s")), + p=deserialize_term(triple_data.get("p")), + o=deserialize_term(triple_data.get("o")), ) - except (json.JSONDecodeError, TypeError) as e: - logger.warning(f"Failed to parse triple JSON: {e}") - # Fallback if parsing fails - return Term(type=LITERAL, value=str(value)) - else: - # Unknown otype, fall back to heuristic - pass + ) + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"Failed to parse triple JSON: {e}") + # Fallback if parsing fails + return Term(type=LITERAL, value=str(value)) + elif term_type is not None: + # Unknown term_type, fall back to heuristic + pass - # Heuristic fallback for backwards compatibility + # Heuristic fallback for backwards compatibility (object values only) if value.startswith("http://") or 
value.startswith("https://"): return Term(type=IRI, iri=value) else: @@ -176,13 +208,13 @@ class Processor(TriplesQueryService): o_val = get_term_value(query.o) g_val = query.g # Already a string or None - # Helper to extract object metadata from result row - def get_o_metadata(t): - """Extract otype/dtype/lang from result row if available""" - otype = getattr(t, 'otype', None) - dtype = getattr(t, 'dtype', None) - lang = getattr(t, 'lang', None) - return otype, dtype, lang + def get_object_metadata(row): + """Extract term type metadata from result row""" + return ( + getattr(row, 'otype', None), + getattr(row, 'dtype', None), + getattr(row, 'lang', None), + ) quads = [] @@ -197,8 +229,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((s_val, p_val, o_val, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((s_val, p_val, o_val, g, term_type, datatype, language)) else: # SP specified resp = self.tg.get_sp( @@ -207,8 +239,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((s_val, p_val, t.o, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((s_val, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: # SO specified @@ -218,8 +250,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((s_val, t.p, o_val, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((s_val, t.p, o_val, g, term_type, datatype, language)) else: # S only resp = self.tg.get_s( @@ -228,8 +260,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = 
get_o_metadata(t) - quads.append((s_val, t.p, t.o, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((s_val, t.p, t.o, g, term_type, datatype, language)) else: if p_val is not None: if o_val is not None: @@ -240,8 +272,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((t.s, p_val, o_val, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((t.s, p_val, o_val, g, term_type, datatype, language)) else: # P only resp = self.tg.get_p( @@ -250,8 +282,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((t.s, p_val, t.o, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((t.s, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: # O only @@ -261,8 +293,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((t.s, t.p, o_val, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((t.s, t.p, o_val, g, term_type, datatype, language)) else: # Nothing specified - get all resp = self.tg.get_all( @@ -272,16 +304,17 @@ class Processor(TriplesQueryService): for t in resp: # Note: quads_by_collection uses 'd' for graph field g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((t.s, t.p, t.o, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((t.s, t.p, t.o, g, term_type, datatype, language)) # Convert to Triple objects (with g field) - # Use otype/dtype/lang for proper Term reconstruction if available + # s and p are always IRIs in RDF + # Object uses 
term_type/datatype/language metadata from database triples = [ Triple( - s=create_term(q[0]), - p=create_term(q[1]), - o=create_term(q[2], otype=q[4], dtype=q[5], lang=q[6]), + s=create_term(q[0], term_type='u'), + p=create_term(q[1], term_type='u'), + o=create_term(q[2], term_type=q[4], datatype=q[5], language=q[6]), g=q[3] if q[3] != DEFAULT_GRAPH else None ) for q in quads @@ -311,12 +344,13 @@ class Processor(TriplesQueryService): o_val = get_term_value(query.o) g_val = query.g - # Helper to extract object metadata from result row - def get_o_metadata(t): - otype = getattr(t, 'otype', None) - dtype = getattr(t, 'dtype', None) - lang = getattr(t, 'lang', None) - return otype, dtype, lang + def get_object_metadata(row): + """Extract term type metadata from result row""" + return ( + getattr(row, 'otype', None), + getattr(row, 'dtype', None), + getattr(row, 'lang', None), + ) # For streaming, we need to execute with fetch_size # Use the collection table for get_all queries (most common streaming case) @@ -345,12 +379,13 @@ class Processor(TriplesQueryService): break g = row.d if hasattr(row, 'd') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(row) + term_type, datatype, language = get_object_metadata(row) + # s and p are always IRIs in RDF triple = Triple( - s=create_term(row.s), - p=create_term(row.p), - o=create_term(row.o, otype=otype, dtype=dtype, lang=lang), + s=create_term(row.s, term_type='u'), + p=create_term(row.p, term_type='u'), + o=create_term(row.o, term_type=term_type, datatype=datatype, language=language), g=g if g != DEFAULT_GRAPH else None ) batch.append(triple) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 8dbeb41b..92f09ebf 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -1,11 +1,27 @@ import asyncio +import hashlib +import json import logging import 
time +import uuid from collections import OrderedDict +from datetime import datetime from ... schema import IRI, LITERAL +# Provenance imports +from trustgraph.provenance import ( + query_session_uri, + retrieval_uri as make_retrieval_uri, + selection_uri as make_selection_uri, + answer_uri as make_answer_uri, + query_session_triples, + retrieval_triples, + selection_triples, + answer_triples, +) + # Module logger logger = logging.getLogger(__name__) @@ -23,6 +39,12 @@ def term_to_string(term): # Fallback return term.iri or term.value or str(term) + +def edge_id(s, p, o): + """Generate an 8-character hash ID for an edge (s, p, o).""" + edge_str = f"{s}|{p}|{o}" + return hashlib.sha256(edge_str.encode()).hexdigest()[:8] + class LRUCacheWithTTL: """LRU cache with TTL for label caching @@ -258,7 +280,14 @@ class Query: return await asyncio.gather(*tasks, return_exceptions=True) async def get_labelgraph(self, query): + """ + Get subgraph with labels resolved for display. + Returns: + tuple: (labeled_edges, uri_map) where: + - labeled_edges: list of (label_s, label_p, label_o) tuples + - uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o) + """ subgraph = await self.get_subgraph(query) # Filter out label triples @@ -281,27 +310,33 @@ class Query: else: label_map[entity] = entity # Fallback to entity itself - # Apply labels to subgraph - sg2 = [] + # Apply labels to subgraph and build URI mapping + labeled_edges = [] + uri_map = {} # Maps edge_id of labeled edge -> original URI triple + for s, p, o in filtered_subgraph: labeled_triple = ( label_map.get(s, s), label_map.get(p, p), label_map.get(o, o) ) - sg2.append(labeled_triple) + labeled_edges.append(labeled_triple) - sg2 = sg2[0:self.max_subgraph_size] + # Map from labeled edge ID to original URIs + labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2]) + uri_map[labeled_eid] = (s, p, o) + + labeled_edges = labeled_edges[0:self.max_subgraph_size] if self.verbose: 
logger.debug("Subgraph:") - for edge in sg2: + for edge in labeled_edges: logger.debug(f" {str(edge)}") if self.verbose: logger.debug("Done.") - return sg2 + return labeled_edges, uri_map class GraphRag: """ @@ -335,11 +370,44 @@ class GraphRag: self, query, user = "trustgraph", collection = "default", entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, max_path_length = 2, streaming = False, chunk_callback = None, + explain_callback = None, save_answer_callback = None, ): + """ + Execute a GraphRAG query with real-time explainability tracking. + Args: + query: The query string + user: User identifier + collection: Collection identifier + entity_limit: Max entities to retrieve + triple_limit: Max triples per entity + max_subgraph_size: Max edges in subgraph + max_path_length: Max hops from seed entities + streaming: Enable streaming LLM response + chunk_callback: async def callback(chunk, end_of_stream) for streaming + explain_callback: async def callback(triples, explain_id) for real-time explainability + save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian + + Returns: + str: The synthesized answer text + """ if self.verbose: logger.debug("Constructing prompt...") + # Generate explainability URIs upfront + session_id = str(uuid.uuid4()) + session_uri = query_session_uri(session_id) + ret_uri = make_retrieval_uri(session_id) + sel_uri = make_selection_uri(session_id) + ans_uri = make_answer_uri(session_id) + + timestamp = datetime.utcnow().isoformat() + "Z" + + # Emit session explainability immediately + if explain_callback: + session_triples = query_session_triples(session_uri, query, timestamp) + await explain_callback(session_triples, session_uri) + q = Query( rag = self, user = user, collection = collection, verbose = self.verbose, entity_limit = entity_limit, @@ -348,24 +416,171 @@ class GraphRag: max_path_length = max_path_length, ) - kg = await q.get_labelgraph(query) + kg, uri_map = await 
q.get_labelgraph(query) + + # Emit retrieval explain after graph retrieval completes + if explain_callback: + ret_triples = retrieval_triples(ret_uri, session_uri, len(kg)) + await explain_callback(ret_triples, ret_uri) if self.verbose: logger.debug("Invoking LLM...") logger.debug(f"Knowledge graph: {kg}") logger.debug(f"Query: {query}") - if streaming and chunk_callback: - resp = await self.prompt_client.kg_prompt( - query, kg, - streaming=True, - chunk_callback=chunk_callback + # Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)} + # uri_map already maps edge_id -> (uri_s, uri_p, uri_o) + edge_map = {} + edges_with_ids = [] + for s, p, o in kg: + eid = edge_id(s, p, o) + edge_map[eid] = (s, p, o) + edges_with_ids.append({ + "id": eid, + "s": s, + "p": p, + "o": o + }) + + if self.verbose: + logger.debug(f"Built edge map with {len(edge_map)} edges") + + # Step 1: Edge Selection - LLM selects relevant edges with reasoning + selection_response = await self.prompt_client.prompt( + "kg-edge-selection", + variables={ + "query": query, + "knowledge": edges_with_ids + } + ) + + if self.verbose: + logger.debug(f"Edge selection response: {selection_response}") + + # Parse response to get selected edge IDs and reasoning + # Response can be a string (JSONL) or a list (JSON array) + selected_ids = set() + selected_edges_with_reasoning = [] # For explain + + if isinstance(selection_response, list): + # JSON array response + for obj in selection_response: + if isinstance(obj, dict) and "id" in obj: + selected_ids.add(obj["id"]) + # Capture original URI edge (not labels) and reasoning for explain + eid = obj["id"] + if eid in uri_map: + # Use original URIs for provenance tracing + uri_s, uri_p, uri_o = uri_map[eid] + selected_edges_with_reasoning.append({ + "edge": (uri_s, uri_p, uri_o), + "reasoning": obj.get("reasoning", ""), + }) + elif isinstance(selection_response, str): + # JSONL string response + for line in selection_response.strip().split('\n'): + line = 
line.strip() + if not line: + continue + try: + obj = json.loads(line) + if "id" in obj: + selected_ids.add(obj["id"]) + # Capture original URI edge (not labels) and reasoning for explain + eid = obj["id"] + if eid in uri_map: + # Use original URIs for provenance tracing + uri_s, uri_p, uri_o = uri_map[eid] + selected_edges_with_reasoning.append({ + "edge": (uri_s, uri_p, uri_o), + "reasoning": obj.get("reasoning", ""), + }) + except json.JSONDecodeError: + logger.warning(f"Failed to parse edge selection line: {line}") + continue + + if self.verbose: + logger.debug(f"Selected {len(selected_ids)} edges: {selected_ids}") + + # Filter to selected edges + selected_edges = [] + for eid in selected_ids: + if eid in edge_map: + selected_edges.append(edge_map[eid]) + + if self.verbose: + logger.debug(f"Filtered to {len(selected_edges)} edges") + + # Emit selection explain after edge selection completes + if explain_callback: + sel_triples = selection_triples( + sel_uri, ret_uri, selected_edges_with_reasoning, session_id ) + await explain_callback(sel_triples, sel_uri) + + # Step 2: Synthesis - LLM generates answer from selected edges only + selected_edge_dicts = [ + {"s": s, "p": p, "o": o} + for s, p, o in selected_edges + ] + if streaming and chunk_callback: + # Accumulate chunks for answer storage while forwarding to callback + accumulated_chunks = [] + + async def accumulating_callback(chunk, end_of_stream): + accumulated_chunks.append(chunk) + await chunk_callback(chunk, end_of_stream) + + await self.prompt_client.prompt( + "kg-synthesis", + variables={ + "query": query, + "knowledge": selected_edge_dicts + }, + streaming=True, + chunk_callback=accumulating_callback + ) + # Combine all chunks into full response + resp = "".join(accumulated_chunks) else: - resp = await self.prompt_client.kg_prompt(query, kg) + resp = await self.prompt_client.prompt( + "kg-synthesis", + variables={ + "query": query, + "knowledge": selected_edge_dicts + } + ) if self.verbose: 
logger.debug("Query processing complete") + # Emit answer explain after synthesis completes + if explain_callback: + answer_doc_id = None + answer_text = resp if resp else "" + + # Save answer to librarian if callback provided + if save_answer_callback and answer_text: + # Generate document ID as URN matching query-time provenance format + answer_doc_id = f"urn:trustgraph:answer:{session_id}" + try: + await save_answer_callback(answer_doc_id, answer_text) + if self.verbose: + logger.debug(f"Saved answer to librarian: {answer_doc_id}") + except Exception as e: + logger.warning(f"Failed to save answer to librarian: {e}") + answer_doc_id = None # Fall back to inline content + + # Generate triples with document reference or inline content + ans_triples = answer_triples( + ans_uri, sel_uri, + answer_text="" if answer_doc_id else answer_text, + document_id=answer_doc_id, + ) + await explain_callback(ans_triples, ans_uri) + + if self.verbose: + logger.debug(f"Emitted explain for session {session_id}") + return resp diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index d8bfbddb..f7b9054f 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -4,18 +4,28 @@ Simple RAG service, performs query using graph RAG an LLM. Input is query, output is response. """ +import asyncio +import base64 import logging +import uuid + from ... schema import GraphRagQuery, GraphRagResponse, Error +from ... schema import Triples, Metadata +from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ... schema import librarian_request_queue, librarian_response_queue from . graph_rag import GraphRag from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec +from ... 
base import Consumer, Producer, ConsumerMetrics, ProducerMetrics # Module logger logger = logging.getLogger(__name__) default_ident = "graph-rag" default_concurrency = 1 +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue class Processor(FlowProcessor): @@ -28,6 +38,7 @@ class Processor(FlowProcessor): triple_limit = params.get("triple_limit", 30) max_subgraph_size = params.get("max_subgraph_size", 150) max_path_length = params.get("max_path_length", 2) + explainability_collection = params.get("explainability_collection", "explainability") super(Processor, self).__init__( **params | { @@ -37,6 +48,7 @@ class Processor(FlowProcessor): "triple_limit": triple_limit, "max_subgraph_size": max_subgraph_size, "max_path_length": max_path_length, + "explainability_collection": explainability_collection, } ) @@ -44,6 +56,7 @@ class Processor(FlowProcessor): self.default_triple_limit = triple_limit self.default_max_subgraph_size = max_subgraph_size self.default_max_path_length = max_path_length + self.explainability_collection = explainability_collection # CRITICAL SECURITY: NEVER share data between users or collections # Each user/collection combination MUST have isolated data access @@ -93,10 +106,163 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + ProducerSpec( + name = "explainability", + schema = Triples, + ) + ) + + # Librarian client for storing answer content + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor=id, flow=None, name="librarian-request" + ) + + self.librarian_request_producer = Producer( + backend=self.pubsub, + topic=librarian_request_q, + schema=LibrarianRequest, + metrics=librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( 
+ processor=id, flow=None, name="librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=librarian_response_q, + subscriber=f"{id}-librarian", + schema=LibrarianResponse, + handler=self.on_librarian_response, + metrics=librarian_response_metrics, + ) + + # Pending librarian requests: request_id -> asyncio.Future + self.pending_librarian_requests = {} + + logger.info("Graph RAG service initialized") + + async def start(self): + await super(Processor, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + """Handle responses from the librarian service.""" + response = msg.value() + request_id = msg.properties().get("id") + + if request_id and request_id in self.pending_librarian_requests: + future = self.pending_librarian_requests.pop(request_id) + future.set_result(response) + else: + logger.warning(f"Received unexpected librarian response: {request_id}") + + async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + """ + Save answer content to the librarian. 
+ + Args: + doc_id: ID for the answer document + user: User ID + content: Answer text content + title: Optional title + timeout: Request timeout in seconds + + Returns: + The document ID on success + """ + request_id = str(uuid.uuid4()) + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or "GraphRAG Answer", + document_type="answer", + ) + + request = LibrarianRequest( + operation="add-document", + document_id=doc_id, + document_metadata=doc_metadata, + content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_librarian_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving answer: {response.error.type}: {response.error.message}" + ) + + return doc_id + + except asyncio.TimeoutError: + self.pending_librarian_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving answer document {doc_id}") + async def on_request(self, msg, consumer, flow): try: + v = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + logger.info(f"Handling input {id}...") + + # Track explainability refs for end_of_session signaling + explainability_refs_emitted = [] + + # Real-time explainability callback - emits triples and IDs as they're generated + async def send_explainability(triples, explain_id): + # Send triples to explainability queue + await flow("explainability").send(Triples( + metadata=Metadata( + id=explain_id, + metadata=[], + user=v.user, + collection=self.explainability_collection, + ), + triples=triples, + )) + + # Send explain ID and collection to response queue + await flow("response").send( + GraphRagResponse( + 
message_type="explain", + explain_id=explain_id, + explain_collection=self.explainability_collection, + ), + properties={"id": id} + ) + + explainability_refs_emitted.append(explain_id) + # CRITICAL SECURITY: Create new GraphRag instance per request # This ensures proper isolation between users and collections # Flow clients are request-scoped and must not be shared @@ -108,13 +274,6 @@ class Processor(FlowProcessor): verbose=True, ) - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - logger.info(f"Handling input {id}...") - if v.entity_limit: entity_limit = v.entity_limit else: @@ -135,6 +294,15 @@ class Processor(FlowProcessor): else: max_path_length = self.default_max_path_length + # Callback to save answer content to librarian + async def save_answer(doc_id, answer_text): + await self.save_answer_content( + doc_id=doc_id, + user=v.user, + content=answer_text, + title=f"GraphRAG Answer: {v.query[:50]}...", + ) + # Check if streaming is requested if v.streaming: # Define async callback for streaming chunks @@ -142,6 +310,7 @@ class Processor(FlowProcessor): async def send_chunk(chunk, end_of_stream): await flow("response").send( GraphRagResponse( + message_type="chunk", response=chunk, end_of_stream=end_of_stream, error=None @@ -149,34 +318,50 @@ class Processor(FlowProcessor): properties={"id": id} ) - # Query with streaming enabled - # All chunks (including final one with end_of_stream=True) are sent via callback - await rag.query( + # Query with streaming and real-time explain + response = await rag.query( query = v.query, user = v.user, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, streaming = True, chunk_callback = send_chunk, + explain_callback = send_explainability, + save_answer_callback = save_answer, ) + else: - # Non-streaming path (existing behavior) + # Non-streaming path with real-time explain response = await 
rag.query( query = v.query, user = v.user, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, + explain_callback = send_explainability, + save_answer_callback = save_answer, ) + # Send chunk with response await flow("response").send( GraphRagResponse( - response = response, - end_of_stream = True, - error = None + message_type="chunk", + response=response, + end_of_stream=True, + error=None, ), - properties = {"id": id} + properties={"id": id} ) + # Send final message to close session + await flow("response").send( + GraphRagResponse( + message_type="chunk", + response="", + end_of_session=True, + ), + properties={"id": id} + ) + logger.info("Request processing complete") except Exception as e: @@ -185,22 +370,18 @@ class Processor(FlowProcessor): logger.debug("Sending error response...") - # Send error response with end_of_stream flag if streaming was requested - error_response = GraphRagResponse( - response = None, - error = Error( - type = "graph-rag-error", - message = str(e), - ), - ) - - # If streaming was requested, indicate stream end - if v.streaming: - error_response.end_of_stream = True - + # Send error response and close session await flow("response").send( - error_response, - properties = {"id": id} + GraphRagResponse( + message_type="chunk", + error=Error( + type="graph-rag-error", + message=str(e), + ), + end_of_stream=True, + end_of_session=True, + ), + properties={"id": id} ) @staticmethod @@ -243,6 +424,12 @@ class Processor(FlowProcessor): help=f'Default max path length (default: 2)' ) + parser.add_argument( + '--explainability-collection', + default='explainability', + help=f'Collection for storing explainability triples (default: explainability)' + ) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index 2c84d21c..e551ed5d 100755 
--- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -443,13 +443,22 @@ class McpServer: gen = manager.request("graph-rag", request_data, flow_id) + text_chunks = [] async for response in gen: + # Handle new message format with message_type + message_type = response.get("message_type", "chunk") - # Extract vectors from response - text = response.get("response", "") - break - - return GraphRagResponse(response=text) + # Only collect text from chunk messages + if message_type == "chunk": + chunk_text = response.get("response", "") + if chunk_text: + text_chunks.append(chunk_text) + + # Check if session is complete + if response.get("end_of_session"): + break + + return GraphRagResponse(response="".join(text_chunks)) async def agent( self,