mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-29 16:39:37 +02:00
GraphRAG Query-Time Explainability (#677)
Implements full explainability pipeline for GraphRAG queries, enabling
traceability from answers back to source documents.
Renamed throughout for clarity:
- provenance_callback → explain_callback
- provenance_id → explain_id
- provenance_collection → explain_collection
- message_type "provenance" → "explain"
- Queue name "provenance" → "explainability"
GraphRAG queries now emit explainability events as they execute:
1. Session - query text and timestamp
2. Retrieval - edges retrieved from subgraph
3. Selection - selected edges with LLM reasoning (JSONL with id +
reasoning)
4. Answer - reference to synthesized response
Events stream via explain_callback during query(), enabling
real-time UX.
- Answers stored in librarian service (not inline in graph - too large)
- Document ID as URN: urn:trustgraph:answer:{session_id}
- Graph stores tg:document reference (IRI) to librarian document
- Added librarian producer/consumer to graph-rag service
- get_labelgraph() now returns (labeled_edges, uri_map)
- uri_map maps edge_id(label_s, label_p, label_o) →
(uri_s, uri_p, uri_o)
- Explainability data stores original URIs, not labels
- Enables tracing edges back to reifying statements via tg:reifies
- Added serialize_triple() to query service (matches storage format)
- get_term_value() now handles TRIPLE type terms
- Enables querying by quoted triple in object position:
?stmt tg:reifies <<s p o>>
- Displays real-time explainability events during query
- Resolves rdfs:label for edge components (s, p, o)
- Traces source chain via prov:wasDerivedFrom to root document
- Output: "Source: Chunk 1 → Page 2 → Document Title"
- Label caching to avoid repeated queries
GraphRagResponse:
- explain_id: str | None
- explain_collection: str | None
- message_type: str ("chunk" or "explain")
- end_of_session: bool
trustgraph-base/trustgraph/provenance/:
- namespaces.py - Added TG_DOCUMENT predicate
- triples.py - answer_triples() supports document_id reference
- uris.py - Added edge_selection_uri()
trustgraph-base/trustgraph/schema/services/retrieval.py:
- GraphRagResponse with explain_id, explain_collection, end_of_session
trustgraph-flow/trustgraph/retrieval/graph_rag/:
- graph_rag.py - URI preservation, streaming answer accumulation
- rag.py - Librarian integration, real-time explain emission
trustgraph-flow/trustgraph/query/triples/cassandra/service.py:
- Quoted triple serialization for query matching
trustgraph-cli/trustgraph/cli/invoke_graph_rag.py:
- Full explainability display with label resolution and source tracing
This commit is contained in:
parent
d2d71f859d
commit
7a6197d8c3
24 changed files with 2001 additions and 323 deletions
|
|
@ -83,13 +83,25 @@ class TestGraphRagIntegration:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
"""Mock prompt client that generates realistic responses for two-step process"""
|
||||
client = AsyncMock()
|
||||
client.kg_prompt.return_value = (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
|
||||
# Mock responses for the two-step process:
|
||||
# 1. kg-edge-selection returns JSONL with edge IDs
|
||||
# 2. kg-synthesis returns the final answer
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Return empty selection (no edges selected) - valid JSONL
|
||||
return ""
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
return ""
|
||||
|
||||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -108,7 +120,7 @@ class TestGraphRagIntegration:
|
|||
async def test_graph_rag_end_to_end_flow(self, graph_rag, mock_embeddings_client,
|
||||
mock_graph_embeddings_client, mock_triples_client,
|
||||
mock_prompt_client):
|
||||
"""Test complete GraphRAG pipeline from query to response"""
|
||||
"""Test complete GraphRAG pipeline from query to response with real-time provenance"""
|
||||
# Arrange
|
||||
query = "What is machine learning?"
|
||||
user = "test_user"
|
||||
|
|
@ -116,13 +128,20 @@ class TestGraphRagIntegration:
|
|||
entity_limit = 50
|
||||
triple_limit = 30
|
||||
|
||||
# Collect provenance events
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
# Act
|
||||
result = await graph_rag.query(
|
||||
response = await graph_rag.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit
|
||||
triple_limit=triple_limit,
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Assert - Verify service coordination
|
||||
|
|
@ -141,16 +160,19 @@ class TestGraphRagIntegration:
|
|||
# 3. Should query triples to build knowledge subgraph
|
||||
assert mock_triples_client.query_stream.call_count > 0
|
||||
|
||||
# 4. Should call prompt with knowledge graph
|
||||
mock_prompt_client.kg_prompt.assert_called_once()
|
||||
call_args = mock_prompt_client.kg_prompt.call_args
|
||||
assert call_args.args[0] == query # First arg is query
|
||||
assert isinstance(call_args.args[1], list) # Second arg is kg (list of triples)
|
||||
# 4. Should call prompt twice (edge selection + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 2
|
||||
|
||||
# Verify final response
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "machine learning" in result.lower()
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
assert "machine learning" in response.lower()
|
||||
|
||||
# Verify provenance was emitted in real-time (4 events: session, retrieval, selection, answer)
|
||||
assert len(provenance_events) == 4
|
||||
for triples, prov_id in provenance_events:
|
||||
assert isinstance(triples, list)
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_with_different_limits(self, graph_rag, mock_embeddings_client,
|
||||
|
|
@ -206,19 +228,25 @@ class TestGraphRagIntegration:
|
|||
mock_graph_embeddings_client.query.return_value = [] # No entities found
|
||||
mock_triples_client.query_stream.return_value = [] # No triples found
|
||||
|
||||
# Collect provenance
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
# Act
|
||||
result = await graph_rag.query(
|
||||
response = await graph_rag.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
collection="test_collection",
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should still call prompt client with empty knowledge graph
|
||||
mock_prompt_client.kg_prompt.assert_called_once()
|
||||
call_args = mock_prompt_client.kg_prompt.call_args
|
||||
assert isinstance(call_args.args[1], list) # kg should be a list
|
||||
assert result is not None
|
||||
# Should still call prompt client (twice: edge selection + synthesis)
|
||||
assert response is not None
|
||||
# Provenance should still be emitted (4 events)
|
||||
assert len(provenance_events) == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):
|
||||
|
|
|
|||
|
|
@ -53,30 +53,34 @@ class TestGraphRagStreaming:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
|
||||
"""Mock prompt client with streaming support"""
|
||||
"""Mock prompt client with streaming support for two-stage GraphRAG"""
|
||||
client = AsyncMock()
|
||||
|
||||
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
# Both modes return the same text
|
||||
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
||||
# Full synthesis text
|
||||
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
chunks = []
|
||||
async for chunk in mock_streaming_llm_response():
|
||||
chunks.append(chunk)
|
||||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "kg-edge-selection":
|
||||
# Edge selection returns JSONL with IDs - simulate selecting first edge
|
||||
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
chunks = []
|
||||
async for chunk in mock_streaming_llm_response():
|
||||
chunks.append(chunk)
|
||||
|
||||
# Send all chunks with end_of_stream=False except the last
|
||||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
# Send all chunks with end_of_stream=False except the last
|
||||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return full_text
|
||||
else:
|
||||
return full_text
|
||||
return ""
|
||||
|
||||
client.kg_prompt.side_effect = kg_prompt_side_effect
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -93,18 +97,25 @@ class TestGraphRagStreaming:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_basic(self, graph_rag_streaming, streaming_chunk_collector):
|
||||
"""Test basic GraphRAG streaming functionality"""
|
||||
"""Test basic GraphRAG streaming functionality with real-time provenance"""
|
||||
# Arrange
|
||||
query = "What is machine learning?"
|
||||
collector = streaming_chunk_collector()
|
||||
|
||||
# Act
|
||||
result = await graph_rag_streaming.query(
|
||||
# Collect provenance events
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
# Act - query() returns response, provenance via callback
|
||||
response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collector.collect
|
||||
chunk_callback=collector.collect,
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
|
@ -116,10 +127,15 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Verify full response matches concatenated chunks
|
||||
full_from_chunks = collector.get_full_text()
|
||||
assert result == full_from_chunks
|
||||
assert response == full_from_chunks
|
||||
|
||||
# Verify content is reasonable
|
||||
assert "machine" in result.lower() or "learning" in result.lower()
|
||||
assert "machine" in response.lower() or "learning" in response.lower()
|
||||
|
||||
# Verify provenance was emitted in real-time (4 events)
|
||||
assert len(provenance_events) == 4
|
||||
for triples, prov_id in provenance_events:
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_vs_non_streaming(self, graph_rag_streaming):
|
||||
|
|
@ -130,7 +146,7 @@ class TestGraphRagStreaming:
|
|||
collection = "test_collection"
|
||||
|
||||
# Act - Non-streaming
|
||||
non_streaming_result = await graph_rag_streaming.query(
|
||||
non_streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
@ -143,7 +159,7 @@ class TestGraphRagStreaming:
|
|||
async def collect(chunk, end_of_stream):
|
||||
streaming_chunks.append(chunk)
|
||||
|
||||
streaming_result = await graph_rag_streaming.query(
|
||||
streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
@ -152,9 +168,9 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_result == non_streaming_result
|
||||
assert streaming_response == non_streaming_response
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_result
|
||||
assert "".join(streaming_chunks) == streaming_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
|
||||
|
|
@ -163,7 +179,7 @@ class TestGraphRagStreaming:
|
|||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await graph_rag_streaming.query(
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
|
|
@ -173,7 +189,7 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Assert
|
||||
assert callback.call_count > 0
|
||||
assert result is not None
|
||||
assert response is not None
|
||||
|
||||
# Verify all callback invocations had string arguments
|
||||
for call in callback.call_args_list:
|
||||
|
|
@ -183,7 +199,7 @@ class TestGraphRagStreaming:
|
|||
async def test_graph_rag_streaming_without_callback(self, graph_rag_streaming):
|
||||
"""Test streaming parameter without callback (should fall back to non-streaming)"""
|
||||
# Arrange & Act
|
||||
result = await graph_rag_streaming.query(
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
|
|
@ -192,8 +208,8 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should complete without error
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
|
||||
|
|
@ -204,7 +220,7 @@ class TestGraphRagStreaming:
|
|||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await graph_rag_streaming.query(
|
||||
response = await graph_rag_streaming.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
|
|
@ -213,7 +229,7 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should still produce streamed response
|
||||
assert result is not None
|
||||
assert response is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -44,18 +44,23 @@ class TestGraphRagStreamingProtocol:
|
|||
"""Mock prompt client that simulates realistic streaming with end_of_stream flags"""
|
||||
client = AsyncMock()
|
||||
|
||||
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
await chunk_callback("The", False)
|
||||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
else:
|
||||
return "The answer is here."
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Edge selection returns empty (no edges selected)
|
||||
return ""
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
await chunk_callback("The", False)
|
||||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
else:
|
||||
return "The answer is here."
|
||||
return ""
|
||||
|
||||
client.kg_prompt.side_effect = kg_prompt_side_effect
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -327,20 +332,24 @@ class TestStreamingProtocolEdgeCases:
|
|||
# Arrange
|
||||
client = AsyncMock()
|
||||
|
||||
async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return ""
|
||||
else:
|
||||
return "textmore"
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
return ""
|
||||
else:
|
||||
return "textmore"
|
||||
return ""
|
||||
|
||||
client.kg_prompt.side_effect = kg_prompt_with_empties
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
rag = GraphRag(
|
||||
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])),
|
||||
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
|
||||
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
prompt_client=client,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue