GraphRAG Query-Time Explainability (#677)

Implements full explainability pipeline for GraphRAG queries, enabling
traceability from answers back to source documents.

Renamed throughout for clarity:
- provenance_callback → explain_callback
- provenance_id → explain_id
- provenance_collection → explain_collection
- message_type "provenance" → "explain"
- Queue name "provenance" → "explainability"

GraphRAG queries now emit explainability events as they execute:
1. Session - query text and timestamp
2. Retrieval - edges retrieved from subgraph
3. Selection - selected edges with LLM reasoning (JSONL with id +
   reasoning)
4. Answer - reference to synthesized response

Events stream via explain_callback during query(), enabling
real-time UX.

- Answers stored in librarian service (not inline in graph - too large)
- Document ID as URN: urn:trustgraph:answer:{session_id}
- Graph stores tg:document reference (IRI) to librarian document
- Added librarian producer/consumer to graph-rag service

- get_labelgraph() now returns (labeled_edges, uri_map)
- uri_map maps edge_id(label_s, label_p, label_o) →
  (uri_s, uri_p, uri_o)
- Explainability data stores original URIs, not labels
- Enables tracing edges back to reifying statements via tg:reifies

- Added serialize_triple() to query service (matches storage format)
- get_term_value() now handles TRIPLE type terms
- Enables querying by quoted triple in object position:
  ?stmt tg:reifies <<s p o>>

- Displays real-time explainability events during query
- Resolves rdfs:label for edge components (s, p, o)
- Traces source chain via prov:wasDerivedFrom to root document
- Output: "Source: Chunk 1 → Page 2 → Document Title"
- Label caching to avoid repeated queries

GraphRagResponse:
- explain_id: str | None
- explain_collection: str | None
- message_type: str ("chunk" or "explain")
- end_of_session: bool

trustgraph-base/trustgraph/provenance/:
- namespaces.py - Added TG_DOCUMENT predicate
- triples.py - answer_triples() supports document_id reference
- uris.py - Added edge_selection_uri()

trustgraph-base/trustgraph/schema/services/retrieval.py:
- GraphRagResponse with explain_id, explain_collection, end_of_session

trustgraph-flow/trustgraph/retrieval/graph_rag/:
- graph_rag.py - URI preservation, streaming answer accumulation
- rag.py - Librarian integration, real-time explain emission

trustgraph-flow/trustgraph/query/triples/cassandra/service.py:
- Quoted triple serialization for query matching

trustgraph-cli/trustgraph/cli/invoke_graph_rag.py:
- Full explainability display with label resolution and source tracing
This commit is contained in:
cybermaggedon 2026-03-10 10:00:01 +00:00 committed by GitHub
parent d2d71f859d
commit 7a6197d8c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 2001 additions and 323 deletions

View file

@ -1,13 +1,27 @@
type: object
description: Graph RAG response
description: Graph RAG response message
properties:
message_type:
type: string
description: Type of message - "chunk" for LLM response chunks, "provenance" for provenance announcements
enum: [chunk, provenance]
example: chunk
response:
type: string
description: Generated response based on retrieved knowledge graph
description: Generated response text (for chunk messages)
example: Quantum physics and computer science intersect in quantum computing...
end-of-stream:
provenance_id:
type: string
description: Provenance node URI (for provenance messages)
example: urn:trustgraph:session:abc123
end_of_stream:
type: boolean
description: Indicates streaming is complete (streaming mode)
description: Indicates LLM response stream is complete
default: false
example: true
end_of_session:
type: boolean
description: Indicates entire session is complete (all messages sent)
default: false
example: true
error:

View file

@ -17,16 +17,18 @@ from trustgraph.messaging import TranslatorRegistry
class TestRAGTranslatorCompletionFlags:
"""Contract tests for RAG response translator completion flags"""
def test_graph_rag_translator_is_final_with_end_of_stream_true(self):
def test_graph_rag_translator_is_final_with_end_of_session_true(self):
"""
Test that GraphRagResponseTranslator returns is_final=True
when end_of_stream=True.
when end_of_session=True.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("graph-rag")
response = GraphRagResponse(
response="A small domesticated mammal.",
message_type="chunk",
end_of_stream=True,
end_of_session=True,
error=None
)
@ -34,20 +36,23 @@ class TestRAGTranslatorCompletionFlags:
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when end_of_stream=True"
assert is_final is True, "is_final must be True when end_of_session=True"
assert response_dict["response"] == "A small domesticated mammal."
assert response_dict["end_of_stream"] is True
assert response_dict["end_of_session"] is True
assert response_dict["message_type"] == "chunk"
def test_graph_rag_translator_is_final_with_end_of_stream_false(self):
def test_graph_rag_translator_is_final_with_end_of_session_false(self):
"""
Test that GraphRagResponseTranslator returns is_final=False
when end_of_stream=False.
when end_of_session=False (even if end_of_stream=True).
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("graph-rag")
response = GraphRagResponse(
response="Chunk 1",
message_type="chunk",
end_of_stream=False,
end_of_session=False,
error=None
)
@ -55,9 +60,55 @@ class TestRAGTranslatorCompletionFlags:
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is False, "is_final must be False when end_of_stream=False"
assert is_final is False, "is_final must be False when end_of_session=False"
assert response_dict["response"] == "Chunk 1"
assert response_dict["end_of_stream"] is False
assert response_dict["end_of_session"] is False
def test_graph_rag_translator_provenance_message(self):
"""
Test that GraphRagResponseTranslator handles provenance messages.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("graph-rag")
response = GraphRagResponse(
response="",
message_type="explain",
explain_id="urn:trustgraph:session:abc123",
end_of_stream=False,
end_of_session=False,
error=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is False
assert response_dict["message_type"] == "explain"
assert response_dict["explain_id"] == "urn:trustgraph:session:abc123"
def test_graph_rag_translator_end_of_stream_not_final(self):
"""
Test that end_of_stream=True alone does NOT make is_final=True.
The session continues with provenance messages after LLM stream completes.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("graph-rag")
response = GraphRagResponse(
response="Final chunk",
message_type="chunk",
end_of_stream=True,
end_of_session=False, # Session continues with provenance
error=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
assert response_dict["end_of_stream"] is True
assert response_dict["end_of_session"] is False
def test_document_rag_translator_is_final_with_end_of_stream_true(self):
"""

View file

@ -83,13 +83,25 @@ class TestGraphRagIntegration:
@pytest.fixture
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
"""Mock prompt client that generates realistic responses for two-step process"""
client = AsyncMock()
client.kg_prompt.return_value = (
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
)
# Mock responses for the two-step process:
# 1. kg-edge-selection returns JSONL with edge IDs
# 2. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
# Return empty selection (no edges selected) - valid JSONL
return ""
elif prompt_name == "kg-synthesis":
return (
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
)
return ""
client.prompt.side_effect = mock_prompt
return client
@pytest.fixture
@ -108,7 +120,7 @@ class TestGraphRagIntegration:
async def test_graph_rag_end_to_end_flow(self, graph_rag, mock_embeddings_client,
mock_graph_embeddings_client, mock_triples_client,
mock_prompt_client):
"""Test complete GraphRAG pipeline from query to response"""
"""Test complete GraphRAG pipeline from query to response with real-time provenance"""
# Arrange
query = "What is machine learning?"
user = "test_user"
@ -116,13 +128,20 @@ class TestGraphRagIntegration:
entity_limit = 50
triple_limit = 30
# Collect provenance events
provenance_events = []
async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id))
# Act
result = await graph_rag.query(
response = await graph_rag.query(
query=query,
user=user,
collection=collection,
entity_limit=entity_limit,
triple_limit=triple_limit
triple_limit=triple_limit,
explain_callback=collect_provenance
)
# Assert - Verify service coordination
@ -141,16 +160,19 @@ class TestGraphRagIntegration:
# 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query_stream.call_count > 0
# 4. Should call prompt with knowledge graph
mock_prompt_client.kg_prompt.assert_called_once()
call_args = mock_prompt_client.kg_prompt.call_args
assert call_args.args[0] == query # First arg is query
assert isinstance(call_args.args[1], list) # Second arg is kg (list of triples)
# 4. Should call prompt twice (edge selection + synthesis)
assert mock_prompt_client.prompt.call_count == 2
# Verify final response
assert result is not None
assert isinstance(result, str)
assert "machine learning" in result.lower()
assert response is not None
assert isinstance(response, str)
assert "machine learning" in response.lower()
# Verify provenance was emitted in real-time (4 events: session, retrieval, selection, answer)
assert len(provenance_events) == 4
for triples, prov_id in provenance_events:
assert isinstance(triples, list)
assert prov_id.startswith("urn:trustgraph:")
@pytest.mark.asyncio
async def test_graph_rag_with_different_limits(self, graph_rag, mock_embeddings_client,
@ -206,19 +228,25 @@ class TestGraphRagIntegration:
mock_graph_embeddings_client.query.return_value = [] # No entities found
mock_triples_client.query_stream.return_value = [] # No triples found
# Collect provenance
provenance_events = []
async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id))
# Act
result = await graph_rag.query(
response = await graph_rag.query(
query="unknown topic",
user="test_user",
collection="test_collection"
collection="test_collection",
explain_callback=collect_provenance
)
# Assert
# Should still call prompt client with empty knowledge graph
mock_prompt_client.kg_prompt.assert_called_once()
call_args = mock_prompt_client.kg_prompt.call_args
assert isinstance(call_args.args[1], list) # kg should be a list
assert result is not None
# Should still call prompt client (twice: edge selection + synthesis)
assert response is not None
# Provenance should still be emitted (4 events)
assert len(provenance_events) == 4
@pytest.mark.asyncio
async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):

View file

@ -53,30 +53,34 @@ class TestGraphRagStreaming:
@pytest.fixture
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
"""Mock prompt client with streaming support"""
"""Mock prompt client with streaming support for two-stage GraphRAG"""
client = AsyncMock()
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
# Both modes return the same text
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
# Full synthesis text
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
chunks = []
async for chunk in mock_streaming_llm_response():
chunks.append(chunk)
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "kg-edge-selection":
# Edge selection returns JSONL with IDs - simulate selecting first edge
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
chunks = []
async for chunk in mock_streaming_llm_response():
chunks.append(chunk)
# Send all chunks with end_of_stream=False except the last
for i, chunk in enumerate(chunks):
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
# Send all chunks with end_of_stream=False except the last
for i, chunk in enumerate(chunks):
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
else:
# Non-streaming response - same text
return full_text
return full_text
else:
return full_text
return ""
client.kg_prompt.side_effect = kg_prompt_side_effect
client.prompt.side_effect = prompt_side_effect
return client
@pytest.fixture
@ -93,18 +97,25 @@ class TestGraphRagStreaming:
@pytest.mark.asyncio
async def test_graph_rag_streaming_basic(self, graph_rag_streaming, streaming_chunk_collector):
"""Test basic GraphRAG streaming functionality"""
"""Test basic GraphRAG streaming functionality with real-time provenance"""
# Arrange
query = "What is machine learning?"
collector = streaming_chunk_collector()
# Act
result = await graph_rag_streaming.query(
# Collect provenance events
provenance_events = []
async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id))
# Act - query() returns response, provenance via callback
response = await graph_rag_streaming.query(
query=query,
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collector.collect
chunk_callback=collector.collect,
explain_callback=collect_provenance
)
# Assert
@ -116,10 +127,15 @@ class TestGraphRagStreaming:
# Verify full response matches concatenated chunks
full_from_chunks = collector.get_full_text()
assert result == full_from_chunks
assert response == full_from_chunks
# Verify content is reasonable
assert "machine" in result.lower() or "learning" in result.lower()
assert "machine" in response.lower() or "learning" in response.lower()
# Verify provenance was emitted in real-time (4 events)
assert len(provenance_events) == 4
for triples, prov_id in provenance_events:
assert prov_id.startswith("urn:trustgraph:")
@pytest.mark.asyncio
async def test_graph_rag_streaming_vs_non_streaming(self, graph_rag_streaming):
@ -130,7 +146,7 @@ class TestGraphRagStreaming:
collection = "test_collection"
# Act - Non-streaming
non_streaming_result = await graph_rag_streaming.query(
non_streaming_response = await graph_rag_streaming.query(
query=query,
user=user,
collection=collection,
@ -143,7 +159,7 @@ class TestGraphRagStreaming:
async def collect(chunk, end_of_stream):
streaming_chunks.append(chunk)
streaming_result = await graph_rag_streaming.query(
streaming_response = await graph_rag_streaming.query(
query=query,
user=user,
collection=collection,
@ -152,9 +168,9 @@ class TestGraphRagStreaming:
)
# Assert - Results should be equivalent
assert streaming_result == non_streaming_result
assert streaming_response == non_streaming_response
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_result
assert "".join(streaming_chunks) == streaming_response
@pytest.mark.asyncio
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
@ -163,7 +179,7 @@ class TestGraphRagStreaming:
callback = AsyncMock()
# Act
result = await graph_rag_streaming.query(
response = await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
@ -173,7 +189,7 @@ class TestGraphRagStreaming:
# Assert
assert callback.call_count > 0
assert result is not None
assert response is not None
# Verify all callback invocations had string arguments
for call in callback.call_args_list:
@ -183,7 +199,7 @@ class TestGraphRagStreaming:
async def test_graph_rag_streaming_without_callback(self, graph_rag_streaming):
"""Test streaming parameter without callback (should fall back to non-streaming)"""
# Arrange & Act
result = await graph_rag_streaming.query(
response = await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
@ -192,8 +208,8 @@ class TestGraphRagStreaming:
)
# Assert - Should complete without error
assert result is not None
assert isinstance(result, str)
assert response is not None
assert isinstance(response, str)
@pytest.mark.asyncio
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
@ -204,7 +220,7 @@ class TestGraphRagStreaming:
callback = AsyncMock()
# Act
result = await graph_rag_streaming.query(
response = await graph_rag_streaming.query(
query="unknown topic",
user="test_user",
collection="test_collection",
@ -213,7 +229,7 @@ class TestGraphRagStreaming:
)
# Assert - Should still produce streamed response
assert result is not None
assert response is not None
assert callback.call_count > 0
@pytest.mark.asyncio

View file

@ -44,18 +44,23 @@ class TestGraphRagStreamingProtocol:
"""Mock prompt client that simulates realistic streaming with end_of_stream flags"""
client = AsyncMock()
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
if streaming and chunk_callback:
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
await chunk_callback("The", False)
await chunk_callback(" answer", False)
await chunk_callback(" is here.", False)
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
return "" # Return value not used since callback handles everything
else:
return "The answer is here."
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
# Edge selection returns empty (no edges selected)
return ""
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
await chunk_callback("The", False)
await chunk_callback(" answer", False)
await chunk_callback(" is here.", False)
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
return "" # Return value not used since callback handles everything
else:
return "The answer is here."
return ""
client.kg_prompt.side_effect = kg_prompt_side_effect
client.prompt.side_effect = prompt_side_effect
return client
@pytest.fixture
@ -327,20 +332,24 @@ class TestStreamingProtocolEdgeCases:
# Arrange
client = AsyncMock()
async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None):
if streaming and chunk_callback:
await chunk_callback("text", False)
await chunk_callback("", False) # Empty but not final
await chunk_callback("more", False)
await chunk_callback("", True) # Empty and final
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
return ""
else:
return "textmore"
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
await chunk_callback("text", False)
await chunk_callback("", False) # Empty but not final
await chunk_callback("more", False)
await chunk_callback("", True) # Empty and final
return ""
else:
return "textmore"
return ""
client.kg_prompt.side_effect = kg_prompt_with_empties
client.prompt.side_effect = prompt_with_empties
rag = GraphRag(
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])),
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
prompt_client=client,

View file

@ -547,21 +547,21 @@ class TestServiceHelperFunctions:
"""Test cases for helper functions in service.py"""
def test_create_term_with_uri_otype(self):
"""Test create_term creates IRI Term for otype='u'"""
"""Test create_term creates IRI Term for term_type='u'"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import IRI
term = create_term('http://example.org/Alice', otype='u')
term = create_term('http://example.org/Alice', term_type='u')
assert term.type == IRI
assert term.iri == 'http://example.org/Alice'
def test_create_term_with_literal_otype(self):
"""Test create_term creates LITERAL Term for otype='l'"""
"""Test create_term creates LITERAL Term for term_type='l'"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import LITERAL
term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en')
term = create_term('Alice Smith', term_type='l', datatype='xsd:string', language='en')
assert term.type == LITERAL
assert term.value == 'Alice Smith'
@ -569,7 +569,7 @@ class TestServiceHelperFunctions:
assert term.language == 'en'
def test_create_term_with_triple_otype(self):
"""Test create_term creates TRIPLE Term for otype='t' with valid JSON"""
"""Test create_term creates TRIPLE Term for term_type='t' with valid JSON"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import TRIPLE, IRI
import json
@ -581,7 +581,7 @@ class TestServiceHelperFunctions:
"o": {"type": "i", "iri": "http://example.org/Bob"},
})
term = create_term(triple_json, otype='t')
term = create_term(triple_json, term_type='t')
assert term.type == TRIPLE
assert term.triple is not None

View file

@ -96,20 +96,21 @@ class TestGraphRagResponseTranslator:
assert is_final is False
assert result["end_of_stream"] is False
# Test final chunk with empty content
# Test final message with end_of_session=True
final_response = GraphRagResponse(
response="",
end_of_stream=True,
end_of_session=True,
error=None
)
# Act
result, is_final = translator.from_response_with_completion(final_response)
# Assert
# Assert - is_final is based on end_of_session, not end_of_stream
assert is_final is True
assert result["response"] == ""
assert result["end_of_stream"] is True
assert result["end_of_session"] is True
class TestDocumentRagResponseTranslator:

View file

@ -118,8 +118,8 @@ class TestCassandraQueryProcessor:
# Verify result contains the queried triple
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].s.iri == 'test_subject'
assert result[0].p.iri == 'test_predicate'
assert result[0].o.value == 'test_object'
def test_processor_initialization_with_defaults(self):
@ -182,8 +182,8 @@ class TestCassandraQueryProcessor:
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].s.iri == 'test_subject'
assert result[0].p.iri == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@ -219,8 +219,8 @@ class TestCassandraQueryProcessor:
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].s.iri == 'test_subject'
assert result[0].p.iri == 'result_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@ -256,8 +256,8 @@ class TestCassandraQueryProcessor:
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].s.iri == 'result_subject'
assert result[0].p.iri == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@ -293,8 +293,8 @@ class TestCassandraQueryProcessor:
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].s.iri == 'result_subject'
assert result[0].p.iri == 'result_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@ -331,8 +331,8 @@ class TestCassandraQueryProcessor:
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
assert len(result) == 1
assert result[0].s.value == 'all_subject'
assert result[0].p.value == 'all_predicate'
assert result[0].s.iri == 'all_subject'
assert result[0].p.iri == 'all_predicate'
assert result[0].o.value == 'all_object'
def test_add_args_method(self):
@ -637,8 +637,8 @@ class TestCassandraQueryPerformanceOptimizations:
)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].s.iri == 'result_subject'
assert result[0].p.iri == 'test_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@ -678,8 +678,8 @@ class TestCassandraQueryPerformanceOptimizations:
)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].s.iri == 'test_subject'
assert result[0].p.iri == 'result_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@ -802,7 +802,7 @@ class TestCassandraQueryPerformanceOptimizations:
# Verify all results were returned
assert len(result) == 5
for i, triple in enumerate(result):
assert triple.s.value == f'subject_{i}' # Mock returns literal values
assert triple.s.iri == f'subject_{i}' # Mock returns literal values
assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
assert triple.p.type == IRI
assert triple.o.iri == 'http://example.com/Person' # URIs use .iri

View file

@ -540,41 +540,68 @@ class TestQuery:
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
# Call get_labelgraph
result = await query.get_labelgraph("test query")
labeled_edges, uri_map = await query.get_labelgraph("test query")
# Verify get_subgraph was called
query.get_subgraph.assert_called_once_with("test query")
# Verify label triples are filtered out
assert len(result) == 2 # Label triple should be excluded
assert len(labeled_edges) == 2 # Label triple should be excluded
# Verify maybe_label was called for non-label triples
expected_calls = [
(("entity1",), {}), (("predicate1",), {}), (("object1",), {}),
(("entity3",), {}), (("predicate3",), {}), (("object3",), {})
]
assert query.maybe_label.call_count == 6
# Verify result contains human-readable labels
expected_result = [
expected_edges = [
("Human Entity One", "Human Predicate One", "Human Object One"),
("Human Entity Three", "Human Predicate Three", "Human Object Three")
]
assert result == expected_result
assert labeled_edges == expected_edges
# Verify uri_map maps labeled edges back to original URIs
assert len(uri_map) == 2
@pytest.mark.asyncio
async def test_graph_rag_query_method(self):
"""Test GraphRag.query method orchestrates full RAG pipeline"""
"""Test GraphRag.query method orchestrates full RAG pipeline with real-time provenance"""
import json
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_graph_embeddings_client = AsyncMock()
mock_triples_client = AsyncMock()
# Mock prompt client response
# Mock prompt client responses for two-step process
expected_response = "This is the RAG response"
mock_prompt_client.kg_prompt.return_value = expected_response
test_labelgraph = [("Subject", "Predicate", "Object")]
# Compute the edge ID for the test edge
test_edge_id = edge_id("Subject", "Predicate", "Object")
# Create uri_map for the test edge (maps labeled edge ID to original URIs)
test_uri_map = {
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
}
# Mock edge selection response (JSONL format)
edge_selection_response = json.dumps({"id": test_edge_id, "reasoning": "relevant"})
# Configure prompt mock to return different responses based on prompt name
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
return edge_selection_response
elif prompt_name == "kg-synthesis":
return expected_response
return ""
mock_prompt_client.prompt = mock_prompt
# Initialize GraphRag
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
@ -583,39 +610,55 @@ class TestQuery:
triples_client=mock_triples_client,
verbose=False
)
# Mock the Query class behavior by patching get_labelgraph
test_labelgraph = [("Subject", "Predicate", "Object")]
# We need to patch the Query class's get_labelgraph method
original_query_init = Query.__init__
original_get_labelgraph = Query.get_labelgraph
def mock_query_init(self, *args, **kwargs):
original_query_init(self, *args, **kwargs)
async def mock_get_labelgraph(self, query_text):
return test_labelgraph
return test_labelgraph, test_uri_map
Query.__init__ = mock_query_init
Query.get_labelgraph = mock_get_labelgraph
# Collect provenance emitted via callback
provenance_events = []
async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id))
try:
# Call GraphRag.query
result = await graph_rag.query(
# Call GraphRag.query with provenance callback
response = await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
entity_limit=25,
triple_limit=15
triple_limit=15,
explain_callback=collect_provenance
)
# Verify prompt client was called with knowledge graph and query
mock_prompt_client.kg_prompt.assert_called_once_with("test query", test_labelgraph)
# Verify result
assert result == expected_response
# Verify response text
assert response == expected_response
# Verify provenance was emitted incrementally (4 events: session, retrieval, selection, answer)
assert len(provenance_events) == 4
# Verify each event has triples and a URN
for triples, prov_id in provenance_events:
assert isinstance(triples, list)
assert len(triples) > 0
assert prov_id.startswith("urn:trustgraph:")
# Verify order: session, retrieval, selection, answer
assert "session" in provenance_events[0][1]
assert "retrieval" in provenance_events[1][1]
assert "selection" in provenance_events[2][1]
assert "answer" in provenance_events[3][1]
finally:
# Restore original methods
Query.__init__ = original_query_init

View file

@ -1,6 +1,7 @@
"""
Unit tests for GraphRAG service non-streaming mode.
Tests that end_of_stream flag is correctly set in non-streaming responses.
Unit tests for GraphRAG service message format.
Tests the new message protocol with message_type, explain_id, and end_of_session.
Real-time explainability emission via callback.
"""
import pytest
@ -11,16 +12,14 @@ from trustgraph.schema import GraphRagQuery, GraphRagResponse
class TestGraphRagService:
"""Test GraphRAG service non-streaming behavior"""
"""Test GraphRAG service message protocol"""
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
@pytest.mark.asyncio
async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_graph_rag_class):
async def test_non_streaming_sends_chunk_then_provenance_messages(self, mock_graph_rag_class):
"""
Test that non-streaming mode sets end_of_stream=True in response.
This is a regression test for the bug where non-streaming responses
didn't set end_of_stream, causing clients to hang waiting for more data.
Test that non-streaming mode sends real-time provenance messages
followed by chunk message with response.
"""
# Setup processor
processor = Processor(
@ -32,10 +31,22 @@ class TestGraphRagService:
max_path_length=2
)
# Setup mock GraphRag instance
# Setup mock GraphRag instance that calls explain_callback
mock_rag_instance = AsyncMock()
mock_graph_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "A small domesticated mammal."
# Mock query() to call the explain_callback with each provenance event
async def mock_query(**kwargs):
explain_callback = kwargs.get('explain_callback')
if explain_callback:
# Simulate real-time provenance emission
await explain_callback([], "urn:trustgraph:session:test")
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
await explain_callback([], "urn:trustgraph:prov:selection:test")
await explain_callback([], "urn:trustgraph:prov:answer:test")
return "A small domesticated mammal."
mock_rag_instance.query.side_effect = mock_query
# Setup message with non-streaming request
msg = MagicMock()
@ -47,7 +58,7 @@ class TestGraphRagService:
triple_limit=30,
max_subgraph_size=150,
max_path_length=2,
streaming=False # Non-streaming mode
streaming=False
)
msg.properties.return_value = {"id": "test-id"}
@ -55,30 +66,48 @@ class TestGraphRagService:
consumer = MagicMock()
flow = MagicMock()
# Mock flow to return AsyncMock for clients and response producer
mock_producer = AsyncMock()
mock_response_producer = AsyncMock()
mock_provenance_producer = AsyncMock()
def flow_router(service_name):
if service_name == "response":
return mock_producer
return AsyncMock() # embeddings, graph-embeddings, triples, prompt clients
return mock_response_producer
elif service_name == "explainability":
return mock_provenance_producer
return AsyncMock()
flow.side_effect = flow_router
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: response was sent with end_of_stream=True
mock_producer.send.assert_called_once()
sent_response = mock_producer.send.call_args[0][0]
assert isinstance(sent_response, GraphRagResponse)
assert sent_response.response == "A small domesticated mammal."
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
assert sent_response.error is None
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
assert mock_response_producer.send.call_count == 6
# First 4 messages are explain (emitted in real-time during query)
for i in range(4):
prov_msg = mock_response_producer.send.call_args_list[i][0][0]
assert prov_msg.message_type == "explain"
assert prov_msg.explain_id is not None
# 5th message is chunk with response
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "A small domesticated mammal."
assert chunk_msg.end_of_stream is True
# 6th message is empty chunk with end_of_session=True
close_msg = mock_response_producer.send.call_args_list[5][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True
# Verify provenance triples were sent to provenance queue
assert mock_provenance_producer.send.call_count == 4
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
@pytest.mark.asyncio
async def test_error_response_in_non_streaming_mode(self, mock_graph_rag_class):
async def test_error_response_closes_session(self, mock_graph_rag_class):
"""
Test that error responses in non-streaming mode set end_of_stream=True.
Test that error responses set end_of_session=True.
"""
# Setup processor
processor = Processor(
@ -105,7 +134,7 @@ class TestGraphRagService:
triple_limit=30,
max_subgraph_size=150,
max_path_length=2,
streaming=False # Non-streaming mode
streaming=False
)
msg.properties.return_value = {"id": "test-id"}
@ -113,22 +142,93 @@ class TestGraphRagService:
consumer = MagicMock()
flow = MagicMock()
mock_producer = AsyncMock()
mock_response_producer = AsyncMock()
mock_provenance_producer = AsyncMock()
def flow_router(service_name):
if service_name == "response":
return mock_producer
return mock_response_producer
elif service_name == "explainability":
return mock_provenance_producer
return AsyncMock()
flow.side_effect = flow_router
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: error response was sent without end_of_stream (not streaming mode)
mock_producer.send.assert_called_once()
sent_response = mock_producer.send.call_args[0][0]
# Verify: error response was sent with session closed
mock_response_producer.send.assert_called_once()
sent_response = mock_response_producer.send.call_args[0][0]
assert isinstance(sent_response, GraphRagResponse)
assert sent_response.response is None
assert sent_response.message_type == "chunk"
assert sent_response.error is not None
assert sent_response.error.message == "Test error"
# Note: error responses in non-streaming mode don't set end_of_stream
# because streaming was never started
assert sent_response.end_of_stream is True
assert sent_response.end_of_session is True
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
@pytest.mark.asyncio
async def test_no_provenance_sends_empty_chunk_to_close(self, mock_graph_rag_class):
"""
Test that when no provenance callback is invoked, an empty chunk closes the session.
"""
# Setup processor
processor = Processor(
taskgroup=MagicMock(),
id="test-processor",
entity_limit=50,
triple_limit=30,
max_subgraph_size=150,
max_path_length=2
)
# Setup mock GraphRag instance that doesn't call provenance callback
mock_rag_instance = AsyncMock()
mock_graph_rag_class.return_value = mock_rag_instance
async def mock_query(**kwargs):
# Don't call explain_callback
return "Response text"
mock_rag_instance.query.side_effect = mock_query
# Setup message
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="Test query",
user="trustgraph",
collection="default",
streaming=False
)
msg.properties.return_value = {"id": "test-id"}
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
mock_response_producer = AsyncMock()
mock_provenance_producer = AsyncMock()
def flow_router(service_name):
if service_name == "response":
return mock_response_producer
elif service_name == "explainability":
return mock_provenance_producer
return AsyncMock()
flow.side_effect = flow_router
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: 2 messages (chunk + empty chunk to close)
assert mock_response_producer.send.call_count == 2
# First is the response chunk
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "Response text"
assert chunk_msg.end_of_stream is True
# Second is empty chunk to close session
close_msg = mock_response_producer.send.call_args_list[1][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True

View file

@ -110,15 +110,25 @@ class AsyncSocketClient:
# Parse different chunk types
chunk = self._parse_chunk(resp)
yield chunk
if chunk is not None: # Skip provenance messages in streaming
yield chunk
# Check if this is the final chunk
if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"):
# Check if this is the final message
# end_of_session indicates entire session is complete (including provenance)
# end_of_dialog is for agent dialogs
# complete is from the gateway envelope
if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"):
break
def _parse_chunk(self, resp: Dict[str, Any]):
"""Parse response chunk into appropriate type"""
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
chunk_type = resp.get("chunk_type")
message_type = resp.get("message_type")
# Handle new GraphRAG message format with message_type
if message_type == "provenance":
# Provenance messages are not yielded to user - they're metadata
return None
if chunk_type == "thought":
return AgentThought(
@ -143,7 +153,7 @@ class AsyncSocketClient:
end_of_message=resp.get("end_of_message", False)
)
else:
# RAG-style chunk (or generic chunk)
# RAG-style chunk (or generic chunk with message_type="chunk")
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
return RAGChunk(

View file

@ -11,7 +11,7 @@ import websockets
from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent
from . exceptions import ProtocolException, raise_from_error_dict
@ -310,15 +310,28 @@ class SocketClient:
# Parse different chunk types
chunk = self._parse_chunk(resp)
yield chunk
if chunk is not None: # Skip provenance messages in streaming
yield chunk
# Check if this is the final chunk
if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"):
# Check if this is the final message
# end_of_session indicates entire session is complete (including provenance)
# end_of_dialog is for agent dialogs
# complete is from the gateway envelope
if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"):
break
def _parse_chunk(self, resp: Dict[str, Any]) -> StreamingChunk:
"""Parse response chunk into appropriate type"""
def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]:
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
chunk_type = resp.get("chunk_type")
message_type = resp.get("message_type")
# Handle new GraphRAG message format with message_type
if message_type == "provenance":
if include_provenance:
# Return provenance event for explainability
return ProvenanceEvent(provenance_id=resp.get("provenance_id", ""))
# Provenance messages are not yielded to user - they're metadata
return None
if chunk_type == "thought":
return AgentThought(
@ -360,7 +373,7 @@ class SocketClient:
end_of_dialog=resp.get("end_of_dialog", False)
)
else:
# RAG-style chunk (or generic chunk)
# RAG-style chunk (or generic chunk with message_type="chunk")
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
return RAGChunk(

View file

@ -202,3 +202,29 @@ class RAGChunk(StreamingChunk):
chunk_type: str = "rag"
end_of_stream: bool = False
error: Optional[Dict[str, str]] = None
@dataclasses.dataclass
class ProvenanceEvent:
"""
Provenance event for explainability.
Emitted during GraphRAG queries when explainable mode is enabled.
Each event represents a provenance node created during query processing.
Attributes:
provenance_id: URI of the provenance node (e.g., urn:trustgraph:session:abc123)
event_type: Type of provenance event (session, retrieval, selection, answer)
"""
provenance_id: str
event_type: str = "" # Derived from provenance_id (session, retrieval, selection, answer)
def __post_init__(self):
# Extract event type from provenance_id
if "session" in self.provenance_id:
self.event_type = "session"
elif "retrieval" in self.provenance_id:
self.event_type = "retrieval"
elif "selection" in self.provenance_id:
self.event_type = "selection"
elif "answer" in self.provenance_id:
self.event_type = "answer"

View file

@ -90,13 +90,31 @@ class GraphRagResponseTranslator(MessageTranslator):
def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
result = {}
# Include response content (even if empty string)
# Include message_type
message_type = getattr(obj, "message_type", "")
if message_type:
result["message_type"] = message_type
# Include response content for chunk messages
if obj.response is not None:
result["response"] = obj.response
# Include end_of_stream flag
# Include explain_id for explain messages
explain_id = getattr(obj, "explain_id", None)
if explain_id:
result["explain_id"] = explain_id
# Include explain_collection for explain messages
explain_collection = getattr(obj, "explain_collection", None)
if explain_collection:
result["explain_collection"] = explain_collection
# Include end_of_stream flag (LLM stream complete)
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
# Include end_of_session flag (entire session complete)
result["end_of_session"] = getattr(obj, "end_of_session", False)
# Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
@ -105,5 +123,6 @@ class GraphRagResponseTranslator(MessageTranslator):
def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
is_final = getattr(obj, 'end_of_stream', False)
# Session is complete when end_of_session is True
is_final = getattr(obj, 'end_of_session', False)
return self.from_pulsar(obj), is_final

View file

@ -40,6 +40,11 @@ from . uris import (
activity_uri,
statement_uri,
agent_uri,
# Query-time provenance URIs
query_session_uri,
retrieval_uri,
selection_uri,
answer_uri,
)
# Namespace constants
@ -58,6 +63,8 @@ from . namespaces import (
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL,
TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH,
# Query-time provenance predicates
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_CONTENT,
)
# Triple builders
@ -65,6 +72,11 @@ from . triples import (
document_triples,
derived_entity_triples,
triple_provenance_triples,
# Query-time provenance triple builders
query_session_triples,
retrieval_triples,
selection_triples,
answer_triples,
)
# Vocabulary bootstrap
@ -86,6 +98,11 @@ __all__ = [
"activity_uri",
"statement_uri",
"agent_uri",
# Query-time provenance URIs
"query_session_uri",
"retrieval_uri",
"selection_uri",
"answer_uri",
# Namespaces
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
"PROV_WAS_DERIVED_FROM", "PROV_WAS_GENERATED_BY",
@ -97,10 +114,17 @@ __all__ = [
"TG_CHUNK_SIZE", "TG_CHUNK_OVERLAP", "TG_COMPONENT_VERSION",
"TG_LLM_MODEL", "TG_ONTOLOGY", "TG_EMBEDDING_MODEL",
"TG_SOURCE_TEXT", "TG_SOURCE_CHAR_OFFSET", "TG_SOURCE_CHAR_LENGTH",
# Query-time provenance predicates
"TG_QUERY", "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_CONTENT",
# Triple builders
"document_triples",
"derived_entity_triples",
"triple_provenance_triples",
# Query-time provenance triple builders
"query_session_triples",
"retrieval_triples",
"selection_triples",
"answer_triples",
# Vocabulary
"get_vocabulary_triples",
"PROV_CLASS_LABELS",

View file

@ -58,3 +58,12 @@ TG_EMBEDDING_MODEL = TG + "embeddingModel"
TG_SOURCE_TEXT = TG + "sourceText"
TG_SOURCE_CHAR_OFFSET = TG + "sourceCharOffset"
TG_SOURCE_CHAR_LENGTH = TG + "sourceCharLength"
# Query-time provenance predicates
TG_QUERY = TG + "query"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_DOCUMENT = TG + "document" # Reference to document in librarian

View file

@ -17,9 +17,12 @@ from . namespaces import (
TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH,
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
TG_LLM_MODEL, TG_ONTOLOGY, TG_REIFIES,
# Query-time provenance predicates
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_CONTENT,
TG_DOCUMENT,
)
from . uris import activity_uri, agent_uri
from . uris import activity_uri, agent_uri, edge_selection_uri
def _iri(uri: str) -> Term:
@ -252,3 +255,177 @@ def triple_provenance_triples(
triples.append(_triple(act_uri, TG_ONTOLOGY, _iri(ontology_uri)))
return triples
# Query-time provenance triple builders
def query_session_triples(
session_uri: str,
query: str,
timestamp: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for a query session activity.
Creates:
- Activity declaration for the query session
- Query text and timestamp
Args:
session_uri: URI of the session (from query_session_uri)
query: The user's query text
timestamp: ISO timestamp (defaults to now)
Returns:
List of Triple objects
"""
if timestamp is None:
timestamp = datetime.utcnow().isoformat() + "Z"
return [
_triple(session_uri, RDF_TYPE, _iri(PROV_ACTIVITY)),
_triple(session_uri, RDFS_LABEL, _literal("GraphRAG query session")),
_triple(session_uri, PROV_STARTED_AT_TIME, _literal(timestamp)),
_triple(session_uri, TG_QUERY, _literal(query)),
]
def retrieval_triples(
retrieval_uri: str,
session_uri: str,
edge_count: int,
) -> List[Triple]:
"""
Build triples for a retrieval entity (all edges retrieved from subgraph).
Creates:
- Entity declaration for retrieval
- wasGeneratedBy link to session
- Edge count metadata
Args:
retrieval_uri: URI of the retrieval entity (from retrieval_uri)
session_uri: URI of the parent session
edge_count: Number of edges retrieved
Returns:
List of Triple objects
"""
return [
_triple(retrieval_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(retrieval_uri, RDFS_LABEL, _literal("Retrieved edges")),
_triple(retrieval_uri, PROV_WAS_GENERATED_BY, _iri(session_uri)),
_triple(retrieval_uri, TG_EDGE_COUNT, _literal(edge_count)),
]
def _quoted_triple(s: str, p: str, o: str) -> Term:
"""Create a quoted triple term (RDF-star) from string values."""
return Term(
type=TRIPLE,
triple=Triple(s=_iri(s), p=_iri(p), o=_iri(o))
)
def selection_triples(
selection_uri: str,
retrieval_uri: str,
selected_edges_with_reasoning: List[dict],
session_id: str = "",
) -> List[Triple]:
"""
Build triples for a selection entity (selected edges with reasoning).
Creates:
- Entity declaration for selection
- wasDerivedFrom link to retrieval
- For each selected edge: an edge selection entity with quoted triple and reasoning
Structure:
<selection> tg:selectedEdge <edge_sel_1> .
<edge_sel_1> tg:edge << <s> <p> <o> >> .
<edge_sel_1> tg:reasoning "reason" .
Args:
selection_uri: URI of the selection entity (from selection_uri)
retrieval_uri: URI of the parent retrieval entity
selected_edges_with_reasoning: List of dicts with 'edge' (s,p,o tuple) and 'reasoning'
session_id: Session UUID for generating edge selection URIs
Returns:
List of Triple objects
"""
triples = [
_triple(selection_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(selection_uri, RDFS_LABEL, _literal("Selected edges")),
_triple(selection_uri, PROV_WAS_DERIVED_FROM, _iri(retrieval_uri)),
]
# Add each selected edge with its reasoning via intermediate entity
for idx, edge_info in enumerate(selected_edges_with_reasoning):
edge = edge_info.get("edge")
reasoning = edge_info.get("reasoning", "")
if edge:
s, p, o = edge
# Create intermediate entity for this edge selection
edge_sel_uri = edge_selection_uri(session_id, idx)
# Link selection to edge selection entity
triples.append(
_triple(selection_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri))
)
# Attach quoted triple to edge selection entity
quoted = _quoted_triple(s, p, o)
triples.append(
Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted)
)
# Attach reasoning to edge selection entity
if reasoning:
triples.append(
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning))
)
return triples
def answer_triples(
answer_uri: str,
selection_uri: str,
answer_text: str = "",
document_id: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for an answer entity (final synthesis text).
Creates:
- Entity declaration for answer
- wasDerivedFrom link to selection
- Either document reference (if document_id provided) or inline content
Args:
answer_uri: URI of the answer entity (from answer_uri)
selection_uri: URI of the parent selection entity
answer_text: The synthesized answer text (used if no document_id)
document_id: Optional librarian document ID (preferred over inline content)
Returns:
List of Triple objects
"""
triples = [
_triple(answer_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(answer_uri, RDFS_LABEL, _literal("GraphRAG answer")),
_triple(answer_uri, PROV_WAS_DERIVED_FROM, _iri(selection_uri)),
]
if document_id:
# Store reference to document in librarian (as IRI)
triples.append(_triple(answer_uri, TG_DOCUMENT, _iri(document_id)))
elif answer_text:
# Fallback: store inline content
triples.append(_triple(answer_uri, TG_CONTENT, _literal(answer_text)))
return triples

View file

@ -60,3 +60,75 @@ def statement_uri(stmt_id: str = None) -> str:
def agent_uri(component_name: str) -> str:
"""Generate URI for a TrustGraph component agent."""
return f"{TRUSTGRAPH_BASE}/agent/{_encode_id(component_name)}"
# Query-time provenance URIs
# These URIs use the urn:trustgraph: namespace to distinguish query-time
# provenance from extraction-time provenance (which uses https://trustgraph.ai/)
def query_session_uri(session_id: str = None) -> str:
"""
Generate URI for a query session activity.
Args:
session_id: Optional UUID string. Auto-generates if not provided.
Returns:
URN in format: urn:trustgraph:session:{uuid}
"""
if session_id is None:
session_id = str(uuid.uuid4())
return f"urn:trustgraph:session:{session_id}"
def retrieval_uri(session_id: str) -> str:
"""
Generate URI for a retrieval entity (edges retrieved from subgraph).
Args:
session_id: The session UUID (same as query_session_uri).
Returns:
URN in format: urn:trustgraph:prov:retrieval:{uuid}
"""
return f"urn:trustgraph:prov:retrieval:{session_id}"
def selection_uri(session_id: str) -> str:
"""
Generate URI for a selection entity (selected edges with reasoning).
Args:
session_id: The session UUID (same as query_session_uri).
Returns:
URN in format: urn:trustgraph:prov:selection:{uuid}
"""
return f"urn:trustgraph:prov:selection:{session_id}"
def answer_uri(session_id: str) -> str:
"""
Generate URI for an answer entity (final synthesis text).
Args:
session_id: The session UUID (same as query_session_uri).
Returns:
URN in format: urn:trustgraph:prov:answer:{uuid}
"""
return f"urn:trustgraph:prov:answer:{session_id}"
def edge_selection_uri(session_id: str, edge_index: int) -> str:
"""
Generate URI for an edge selection item (links edge to reasoning).
Args:
session_id: The session UUID.
edge_index: Index of this edge in the selection (0-based).
Returns:
URN in format: urn:trustgraph:prov:edge:{uuid}:{index}
"""
return f"urn:trustgraph:prov:edge:{session_id}:{edge_index}"

View file

@ -21,7 +21,11 @@ class GraphRagQuery:
class GraphRagResponse:
error: Error | None = None
response: str = ""
end_of_stream: bool = False
end_of_stream: bool = False # LLM response stream complete
explain_id: str | None = None # Single explain URI (announced as created)
explain_collection: str | None = None # Collection where explain was stored
message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete
############################################################################

View file

@ -3,7 +3,11 @@ Uses the GraphRAG service to answer a question
"""
import argparse
import json
import os
import sys
import websockets
import asyncio
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
@ -15,11 +19,609 @@ default_triple_limit = 30
default_max_subgraph_size = 150
default_max_path_length = 2
# Provenance predicates
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_REIFIES = TG + "reifies"
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
def _get_event_type(prov_id):
"""Extract event type from provenance_id"""
if "session" in prov_id:
return "session"
elif "retrieval" in prov_id:
return "retrieval"
elif "selection" in prov_id:
return "selection"
elif "answer" in prov_id:
return "answer"
return "provenance"
def _format_provenance_details(event_type, triples):
"""Format provenance details based on event type and triples"""
lines = []
if event_type == "session":
# Show query and timestamp
for s, p, o in triples:
if p == TG_QUERY:
lines.append(f" Query: {o}")
elif p == PROV_STARTED_AT_TIME:
lines.append(f" Time: {o}")
elif event_type == "retrieval":
# Show edge count
for s, p, o in triples:
if p == TG_EDGE_COUNT:
lines.append(f" Edges retrieved: {o}")
elif event_type == "selection":
# For selection, just count edge selection URIs
# The actual edge details are fetched separately via edge_selections parameter
edge_sel_uris = []
for s, p, o in triples:
if p == TG_SELECTED_EDGE:
edge_sel_uris.append(o)
if edge_sel_uris:
lines.append(f" Selected {len(edge_sel_uris)} edge(s)")
elif event_type == "answer":
# Show content length (not full content - it's already streamed)
for s, p, o in triples:
if p == TG_CONTENT:
lines.append(f" Answer length: {len(o)} chars")
return lines
async def _query_triples_once(ws_url, flow_id, prov_id, user, collection, debug=False):
"""Query triples for a provenance node (single attempt)"""
request = {
"id": "triples-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": prov_id},
"user": user,
"collection": collection,
"limit": 100
}
}
if debug:
print(f" [debug] querying triples for s={prov_id}", file=sys.stderr)
triples = []
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if debug:
print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr)
if response.get("id") != "triples-request":
continue
if "error" in response:
if debug:
print(f" [debug] error: {response['error']}", file=sys.stderr)
break
if "response" in response:
resp = response["response"]
# Handle triples response
# Response format: {"response": [triples...]}
# Each triple uses compact keys: "i" for iri, "v" for value, "t" for type
triple_list = resp.get("response", [])
for t in triple_list:
s = t.get("s", {}).get("i", t.get("s", {}).get("v", ""))
p = t.get("p", {}).get("i", t.get("p", {}).get("v", ""))
# Handle quoted triples (type "t") and regular values
o_term = t.get("o", {})
if o_term.get("t") == "t":
# Quoted triple - extract s, p, o from nested structure
tr = o_term.get("tr", {})
o = {
"s": tr.get("s", {}).get("i", ""),
"p": tr.get("p", {}).get("i", ""),
"o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")),
}
else:
o = o_term.get("i", o_term.get("v", ""))
triples.append((s, p, o))
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception: {e}", file=sys.stderr)
if debug:
print(f" [debug] got {len(triples)} triples", file=sys.stderr)
return triples
async def _query_triples(ws_url, flow_id, prov_id, user, collection, max_retries=5, retry_delay=0.2, debug=False):
"""Query triples for a provenance node with retries for race condition"""
for attempt in range(max_retries):
triples = await _query_triples_once(ws_url, flow_id, prov_id, user, collection, debug)
if triples:
return triples
# Wait before retry if empty (triples may not be stored yet)
if attempt < max_retries - 1:
if debug:
print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr)
await asyncio.sleep(retry_delay)
return []
async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, collection, debug=False):
"""
Query for provenance of an edge (s, p, o) in the knowledge graph.
Finds statements that reify the edge via tg:reifies, then follows
prov:wasDerivedFrom to find source documents.
Returns list of source URIs (chunks, pages, documents).
"""
# Query for statements that reify this edge: ?stmt tg:reifies <<s p o>>
request = {
"id": "edge-prov-request",
"service": "triples",
"flow": flow_id,
"request": {
"p": {"t": "i", "i": TG_REIFIES},
"o": {
"t": "t", # Quoted triple type
"tr": {
"s": {"t": "i", "i": edge_s},
"p": {"t": "i", "i": edge_p},
"o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o},
}
},
"user": user,
"collection": collection,
"limit": 10
}
}
if debug:
print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr)
stmt_uris = []
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "edge-prov-request":
continue
if "error" in response:
if debug:
print(f" [debug] error: {response['error']}", file=sys.stderr)
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
for t in triple_list:
s = t.get("s", {}).get("i", "")
if s:
stmt_uris.append(s)
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr)
if debug:
print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr)
# For each statement, query wasDerivedFrom to find sources
sources = []
for stmt_uri in stmt_uris:
# Query: stmt_uri prov:wasDerivedFrom ?source
request = {
"id": "derived-from-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": stmt_uri},
"p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
"user": user,
"collection": collection,
"limit": 10
}
}
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "derived-from-request":
continue
if "error" in response:
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
for t in triple_list:
o = t.get("o", {}).get("i", "")
if o:
sources.append(o)
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr)
if debug:
print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr)
return sources
async def _query_derived_from(ws_url, flow_id, uri, user, collection, debug=False):
"""Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent."""
request = {
"id": "parent-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": uri},
"p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
"user": user,
"collection": collection,
"limit": 1
}
}
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "parent-request":
continue
if "error" in response:
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
if triple_list:
return triple_list[0].get("o", {}).get("i", None)
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying parent: {e}", file=sys.stderr)
return None
async def _trace_provenance_chain(ws_url, flow_id, source_uri, user, collection, label_cache, debug=False):
"""
Trace the full provenance chain from a source URI up to the root document.
Returns a list of (uri, label) tuples from leaf to root.
"""
chain = []
current = source_uri
max_depth = 10 # Prevent infinite loops
for _ in range(max_depth):
if not current:
break
# Get label for current entity
label = await _query_label(ws_url, flow_id, current, user, collection, label_cache, debug)
chain.append((current, label))
# Get parent
parent = await _query_derived_from(ws_url, flow_id, current, user, collection, debug)
if not parent or parent == current:
break
current = parent
return chain
def _format_provenance_chain(chain):
"""
Format a provenance chain as a human-readable string.
Chain is [(uri, label), ...] from leaf to root.
"""
if not chain:
return ""
# Show labels, from leaf to root
labels = [label for uri, label in chain]
return "".join(labels)
def _is_iri(value):
"""Check if a value looks like an IRI."""
if not isinstance(value, str):
return False
return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:")
async def _query_label(ws_url, flow_id, iri, user, collection, label_cache, debug=False):
"""
Query for the rdfs:label of an IRI.
Uses label_cache to avoid repeated queries.
Returns the label if found, otherwise returns the IRI.
"""
if not _is_iri(iri):
return iri
# Check cache first
if iri in label_cache:
return label_cache[iri]
request = {
"id": "label-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": iri},
"p": {"t": "i", "i": RDFS_LABEL},
"user": user,
"collection": collection,
"limit": 1
}
}
label = iri # Default to IRI if no label found
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "label-request":
continue
if "error" in response:
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
if triple_list:
# Get the label value
o = triple_list[0].get("o", {})
label = o.get("v", o.get("i", iri))
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr)
# Cache the result
label_cache[iri] = label
return label
async def _resolve_edge_labels(ws_url, flow_id, edge_triple, user, collection, label_cache, debug=False):
"""
Resolve labels for all IRI components of an edge triple.
Returns (s_label, p_label, o_label).
"""
s = edge_triple.get("s", "?")
p = edge_triple.get("p", "?")
o = edge_triple.get("o", "?")
s_label = await _query_label(ws_url, flow_id, s, user, collection, label_cache, debug)
p_label = await _query_label(ws_url, flow_id, p, user, collection, label_cache, debug)
o_label = await _query_label(ws_url, flow_id, o, user, collection, label_cache, debug)
return s_label, p_label, o_label
async def _question_explainable(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, token=None, debug=False
):
"""Execute graph RAG with explainability - shows provenance events with details"""
# Convert HTTP URL to WebSocket URL
if url.startswith("http://"):
ws_url = url.replace("http://", "ws://", 1)
elif url.startswith("https://"):
ws_url = url.replace("https://", "wss://", 1)
else:
ws_url = f"ws://{url}"
ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"
if token:
ws_url = f"{ws_url}?token={token}"
# Cache for label lookups to avoid repeated queries
label_cache = {}
request = {
"id": "cli-request",
"service": "graph-rag",
"flow": flow_id,
"request": {
"query": question,
"user": user,
"collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-path-length": max_path_length,
"streaming": True
}
}
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "cli-request":
continue
if "error" in response:
print(f"\nError: {response['error']}", file=sys.stderr)
break
if "response" in response:
resp = response["response"]
# Check for errors in response
if "error" in resp and resp["error"]:
err = resp["error"]
print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr)
break
message_type = resp.get("message_type", "")
if debug:
print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr)
if message_type == "explain":
# Display explain event with details
explain_id = resp.get("explain_id", "")
explain_collection = resp.get("explain_collection", "explainability")
if explain_id:
event_type = _get_event_type(explain_id)
print(f"\n [{event_type}] {explain_id}", file=sys.stderr)
# Query triples for this explain node (using explain collection from event)
triples = await _query_triples(
ws_url, flow_id, explain_id, user, explain_collection, debug=debug
)
# Format and display details
details = _format_provenance_details(event_type, triples)
for line in details:
print(line, file=sys.stderr)
# For selection events, query each edge selection for details
if event_type == "selection":
for s, p, o in triples:
if debug:
print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr)
if p == TG_SELECTED_EDGE and isinstance(o, str):
if debug:
print(f" [debug] querying edge selection: {o}", file=sys.stderr)
# Query the edge selection entity (using explain collection from event)
edge_triples = await _query_triples(
ws_url, flow_id, o, user, explain_collection, debug=debug
)
if debug:
print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr)
# Extract edge and reasoning
edge_triple = None # Store the actual triple for provenance lookup
reasoning = None
for es, ep, eo in edge_triples:
if debug:
print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr)
if ep == TG_EDGE and isinstance(eo, dict):
# eo is a quoted triple dict
edge_triple = eo
elif ep == TG_REASONING:
reasoning = eo
if edge_triple:
# Resolve labels for edge components
s_label, p_label, o_label = await _resolve_edge_labels(
ws_url, flow_id, edge_triple, user, collection,
label_cache, debug=debug
)
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
if reasoning:
r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
print(f" Reason: {r_short}", file=sys.stderr)
# Trace edge provenance in the user's collection (not explainability)
if edge_triple:
sources = await _query_edge_provenance(
ws_url, flow_id,
edge_triple.get("s", ""),
edge_triple.get("p", ""),
edge_triple.get("o", ""),
user, collection, # Use the query collection, not explainability
debug=debug
)
if sources:
for src in sources:
# Trace full chain from source to root document
chain = await _trace_provenance_chain(
ws_url, flow_id, src, user, collection,
label_cache, debug=debug
)
chain_str = _format_provenance_chain(chain)
print(f" Source: {chain_str}", file=sys.stderr)
elif message_type == "chunk" or not message_type:
# Display response chunk
chunk = resp.get("response", "")
if chunk:
print(chunk, end="", flush=True)
# Check if session is complete
if resp.get("end_of_session"):
break
print() # Final newline
def question(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, streaming=True, token=None
max_subgraph_size, max_path_length, streaming=True, token=None,
explainable=False, debug=False
):
# Explainable mode uses direct websocket to capture provenance events
if explainable:
asyncio.run(_question_explainable(
url=url,
flow_id=flow_id,
question=question,
user=user,
collection=collection,
entity_limit=entity_limit,
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
token=token,
debug=debug
))
return
# Create API client
api = Api(url=url, token=token)
@ -138,6 +740,18 @@ def main():
help='Disable streaming (use non-streaming mode)'
)
parser.add_argument(
'-x', '--explainable',
action='store_true',
help='Show provenance events for explainability (implies streaming)'
)
parser.add_argument(
'--debug',
action='store_true',
help='Show debug output for troubleshooting'
)
args = parser.parse_args()
try:
@ -154,6 +768,8 @@ def main():
max_path_length=args.max_path_length,
streaming=not args.no_streaming,
token=args.token,
explainable=args.explainable,
debug=args.debug,
)
except Exception as e:

View file

@ -13,7 +13,7 @@ from .... direct.cassandra_kg import (
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
)
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Term, Triple, IRI, LITERAL, TRIPLE
from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK
from .... base import TriplesQueryService
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
@ -23,6 +23,36 @@ logger = logging.getLogger(__name__)
default_ident = "triples-query"
def serialize_triple(triple):
"""Serialize a Triple object to JSON for querying (must match storage format)."""
if triple is None:
return None
def term_to_dict(term):
if term is None:
return None
result = {"type": term.type}
if term.type == IRI:
result["iri"] = term.iri
elif term.type == LITERAL:
result["value"] = term.value
if term.datatype:
result["datatype"] = term.datatype
if term.language:
result["language"] = term.language
elif term.type == BLANK:
result["id"] = term.id
elif term.type == TRIPLE:
result["triple"] = serialize_triple(term.triple)
return result
return json.dumps({
"s": term_to_dict(triple.s),
"p": term_to_dict(triple.p),
"o": term_to_dict(triple.o),
})
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
@ -31,6 +61,9 @@ def get_term_value(term):
return term.iri
elif term.type == LITERAL:
return term.value
elif term.type == TRIPLE:
# Serialize nested triple to JSON (must match storage format)
return serialize_triple(term.triple)
else:
# For blank nodes or other types, use id or value
return term.id or term.value
@ -66,51 +99,50 @@ def deserialize_term(term_dict):
return Term(type=LITERAL, value=str(term_dict))
def create_term(value, otype=None, dtype=None, lang=None):
def create_term(value, term_type=None, datatype=None, language=None):
"""
Create a Term from a string value, optionally using type metadata.
Args:
value: The string value
otype: Object type - 'u' (URI), 'l' (literal), 't' (triple)
dtype: XSD datatype (for literals)
lang: Language tag (for literals)
term_type: 'u' (IRI), 'l' (literal), 't' (triple)
datatype: XSD datatype for literals
language: Language tag for literals
If otype is provided, uses it to determine Term type.
Otherwise falls back to URL detection heuristic.
If term_type is provided, uses it to determine Term type.
Otherwise falls back to URL detection heuristic for object values.
"""
if otype is not None:
if otype == 'u':
return Term(type=IRI, iri=value)
elif otype == 'l':
return Term(
type=LITERAL,
value=value,
datatype=dtype or "",
language=lang or ""
)
elif otype == 't':
# Triple/reification - parse JSON and create nested Triple
try:
triple_data = json.loads(value) if isinstance(value, str) else value
if isinstance(triple_data, dict):
return Term(
type=TRIPLE,
triple=Triple(
s=deserialize_term(triple_data.get("s")),
p=deserialize_term(triple_data.get("p")),
o=deserialize_term(triple_data.get("o")),
)
if term_type == 'u':
return Term(type=IRI, iri=value)
elif term_type == 'l':
return Term(
type=LITERAL,
value=value,
datatype=datatype or "",
language=language or ""
)
elif term_type == 't':
# Triple/reification - parse JSON and create nested Triple
try:
triple_data = json.loads(value) if isinstance(value, str) else value
if isinstance(triple_data, dict):
return Term(
type=TRIPLE,
triple=Triple(
s=deserialize_term(triple_data.get("s")),
p=deserialize_term(triple_data.get("p")),
o=deserialize_term(triple_data.get("o")),
)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"Failed to parse triple JSON: {e}")
# Fallback if parsing fails
return Term(type=LITERAL, value=str(value))
else:
# Unknown otype, fall back to heuristic
pass
)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"Failed to parse triple JSON: {e}")
# Fallback if parsing fails
return Term(type=LITERAL, value=str(value))
elif term_type is not None:
# Unknown term_type, fall back to heuristic
pass
# Heuristic fallback for backwards compatibility
# Heuristic fallback for backwards compatibility (object values only)
if value.startswith("http://") or value.startswith("https://"):
return Term(type=IRI, iri=value)
else:
@ -176,13 +208,13 @@ class Processor(TriplesQueryService):
o_val = get_term_value(query.o)
g_val = query.g # Already a string or None
# Helper to extract object metadata from result row
def get_o_metadata(t):
"""Extract otype/dtype/lang from result row if available"""
otype = getattr(t, 'otype', None)
dtype = getattr(t, 'dtype', None)
lang = getattr(t, 'lang', None)
return otype, dtype, lang
def get_object_metadata(row):
"""Extract term type metadata from result row"""
return (
getattr(row, 'otype', None),
getattr(row, 'dtype', None),
getattr(row, 'lang', None),
)
quads = []
@ -197,8 +229,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, p_val, o_val, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
else:
# SP specified
resp = self.tg.get_sp(
@ -207,8 +239,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, p_val, t.o, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, p_val, t.o, g, term_type, datatype, language))
else:
if o_val is not None:
# SO specified
@ -218,8 +250,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, t.p, o_val, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
else:
# S only
resp = self.tg.get_s(
@ -228,8 +260,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, t.p, t.o, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, t.p, t.o, g, term_type, datatype, language))
else:
if p_val is not None:
if o_val is not None:
@ -240,8 +272,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, p_val, o_val, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
else:
# P only
resp = self.tg.get_p(
@ -250,8 +282,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, p_val, t.o, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, p_val, t.o, g, term_type, datatype, language))
else:
if o_val is not None:
# O only
@ -261,8 +293,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, t.p, o_val, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
else:
# Nothing specified - get all
resp = self.tg.get_all(
@ -272,16 +304,17 @@ class Processor(TriplesQueryService):
for t in resp:
# Note: quads_by_collection uses 'd' for graph field
g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, t.p, t.o, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, t.p, t.o, g, term_type, datatype, language))
# Convert to Triple objects (with g field)
# Use otype/dtype/lang for proper Term reconstruction if available
# s and p are always IRIs in RDF
# Object uses term_type/datatype/language metadata from database
triples = [
Triple(
s=create_term(q[0]),
p=create_term(q[1]),
o=create_term(q[2], otype=q[4], dtype=q[5], lang=q[6]),
s=create_term(q[0], term_type='u'),
p=create_term(q[1], term_type='u'),
o=create_term(q[2], term_type=q[4], datatype=q[5], language=q[6]),
g=q[3] if q[3] != DEFAULT_GRAPH else None
)
for q in quads
@ -311,12 +344,13 @@ class Processor(TriplesQueryService):
o_val = get_term_value(query.o)
g_val = query.g
# Helper to extract object metadata from result row
def get_o_metadata(t):
otype = getattr(t, 'otype', None)
dtype = getattr(t, 'dtype', None)
lang = getattr(t, 'lang', None)
return otype, dtype, lang
def get_object_metadata(row):
"""Extract term type metadata from result row"""
return (
getattr(row, 'otype', None),
getattr(row, 'dtype', None),
getattr(row, 'lang', None),
)
# For streaming, we need to execute with fetch_size
# Use the collection table for get_all queries (most common streaming case)
@ -345,12 +379,13 @@ class Processor(TriplesQueryService):
break
g = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(row)
term_type, datatype, language = get_object_metadata(row)
# s and p are always IRIs in RDF
triple = Triple(
s=create_term(row.s),
p=create_term(row.p),
o=create_term(row.o, otype=otype, dtype=dtype, lang=lang),
s=create_term(row.s, term_type='u'),
p=create_term(row.p, term_type='u'),
o=create_term(row.o, term_type=term_type, datatype=datatype, language=language),
g=g if g != DEFAULT_GRAPH else None
)
batch.append(triple)

View file

@ -1,11 +1,27 @@
import asyncio
import hashlib
import json
import logging
import time
import uuid
from collections import OrderedDict
from datetime import datetime
from ... schema import IRI, LITERAL
# Provenance imports
from trustgraph.provenance import (
query_session_uri,
retrieval_uri as make_retrieval_uri,
selection_uri as make_selection_uri,
answer_uri as make_answer_uri,
query_session_triples,
retrieval_triples,
selection_triples,
answer_triples,
)
# Module logger
logger = logging.getLogger(__name__)
@ -23,6 +39,12 @@ def term_to_string(term):
# Fallback
return term.iri or term.value or str(term)
def edge_id(s, p, o):
"""Generate an 8-character hash ID for an edge (s, p, o)."""
edge_str = f"{s}|{p}|{o}"
return hashlib.sha256(edge_str.encode()).hexdigest()[:8]
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
@ -258,7 +280,14 @@ class Query:
return await asyncio.gather(*tasks, return_exceptions=True)
async def get_labelgraph(self, query):
"""
Get subgraph with labels resolved for display.
Returns:
tuple: (labeled_edges, uri_map) where:
- labeled_edges: list of (label_s, label_p, label_o) tuples
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
"""
subgraph = await self.get_subgraph(query)
# Filter out label triples
@ -281,27 +310,33 @@ class Query:
else:
label_map[entity] = entity # Fallback to entity itself
# Apply labels to subgraph
sg2 = []
# Apply labels to subgraph and build URI mapping
labeled_edges = []
uri_map = {} # Maps edge_id of labeled edge -> original URI triple
for s, p, o in filtered_subgraph:
labeled_triple = (
label_map.get(s, s),
label_map.get(p, p),
label_map.get(o, o)
)
sg2.append(labeled_triple)
labeled_edges.append(labeled_triple)
sg2 = sg2[0:self.max_subgraph_size]
# Map from labeled edge ID to original URIs
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
uri_map[labeled_eid] = (s, p, o)
labeled_edges = labeled_edges[0:self.max_subgraph_size]
if self.verbose:
logger.debug("Subgraph:")
for edge in sg2:
for edge in labeled_edges:
logger.debug(f" {str(edge)}")
if self.verbose:
logger.debug("Done.")
return sg2
return labeled_edges, uri_map
class GraphRag:
"""
@ -335,11 +370,44 @@ class GraphRag:
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, streaming = False, chunk_callback = None,
explain_callback = None, save_answer_callback = None,
):
"""
Execute a GraphRAG query with real-time explainability tracking.
Args:
query: The query string
user: User identifier
collection: Collection identifier
entity_limit: Max entities to retrieve
triple_limit: Max triples per entity
max_subgraph_size: Max edges in subgraph
max_path_length: Max hops from seed entities
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for real-time explainability
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
Returns:
str: The synthesized answer text
"""
if self.verbose:
logger.debug("Constructing prompt...")
# Generate explainability URIs upfront
session_id = str(uuid.uuid4())
session_uri = query_session_uri(session_id)
ret_uri = make_retrieval_uri(session_id)
sel_uri = make_selection_uri(session_id)
ans_uri = make_answer_uri(session_id)
timestamp = datetime.utcnow().isoformat() + "Z"
# Emit session explainability immediately
if explain_callback:
session_triples = query_session_triples(session_uri, query, timestamp)
await explain_callback(session_triples, session_uri)
q = Query(
rag = self, user = user, collection = collection,
verbose = self.verbose, entity_limit = entity_limit,
@ -348,24 +416,171 @@ class GraphRag:
max_path_length = max_path_length,
)
kg = await q.get_labelgraph(query)
kg, uri_map = await q.get_labelgraph(query)
# Emit retrieval explain after graph retrieval completes
if explain_callback:
ret_triples = retrieval_triples(ret_uri, session_uri, len(kg))
await explain_callback(ret_triples, ret_uri)
if self.verbose:
logger.debug("Invoking LLM...")
logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}")
if streaming and chunk_callback:
resp = await self.prompt_client.kg_prompt(
query, kg,
streaming=True,
chunk_callback=chunk_callback
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
edge_map = {}
edges_with_ids = []
for s, p, o in kg:
eid = edge_id(s, p, o)
edge_map[eid] = (s, p, o)
edges_with_ids.append({
"id": eid,
"s": s,
"p": p,
"o": o
})
if self.verbose:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1: Edge Selection - LLM selects relevant edges with reasoning
selection_response = await self.prompt_client.prompt(
"kg-edge-selection",
variables={
"query": query,
"knowledge": edges_with_ids
}
)
if self.verbose:
logger.debug(f"Edge selection response: {selection_response}")
# Parse response to get selected edge IDs and reasoning
# Response can be a string (JSONL) or a list (JSON array)
selected_ids = set()
selected_edges_with_reasoning = [] # For explain
if isinstance(selection_response, list):
# JSON array response
for obj in selection_response:
if isinstance(obj, dict) and "id" in obj:
selected_ids.add(obj["id"])
# Capture original URI edge (not labels) and reasoning for explain
eid = obj["id"]
if eid in uri_map:
# Use original URIs for provenance tracing
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": obj.get("reasoning", ""),
})
elif isinstance(selection_response, str):
# JSONL string response
for line in selection_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
if "id" in obj:
selected_ids.add(obj["id"])
# Capture original URI edge (not labels) and reasoning for explain
eid = obj["id"]
if eid in uri_map:
# Use original URIs for provenance tracing
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": obj.get("reasoning", ""),
})
except json.JSONDecodeError:
logger.warning(f"Failed to parse edge selection line: {line}")
continue
if self.verbose:
logger.debug(f"Selected {len(selected_ids)} edges: {selected_ids}")
# Filter to selected edges
selected_edges = []
for eid in selected_ids:
if eid in edge_map:
selected_edges.append(edge_map[eid])
if self.verbose:
logger.debug(f"Filtered to {len(selected_edges)} edges")
# Emit selection explain after edge selection completes
if explain_callback:
sel_triples = selection_triples(
sel_uri, ret_uri, selected_edges_with_reasoning, session_id
)
await explain_callback(sel_triples, sel_uri)
# Step 2: Synthesis - LLM generates answer from selected edges only
selected_edge_dicts = [
{"s": s, "p": p, "o": o}
for s, p, o in selected_edges
]
if streaming and chunk_callback:
# Accumulate chunks for answer storage while forwarding to callback
accumulated_chunks = []
async def accumulating_callback(chunk, end_of_stream):
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
await self.prompt_client.prompt(
"kg-synthesis",
variables={
"query": query,
"knowledge": selected_edge_dicts
},
streaming=True,
chunk_callback=accumulating_callback
)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
resp = await self.prompt_client.kg_prompt(query, kg)
resp = await self.prompt_client.prompt(
"kg-synthesis",
variables={
"query": query,
"knowledge": selected_edge_dicts
}
)
if self.verbose:
logger.debug("Query processing complete")
# Emit answer explain after synthesis completes
if explain_callback:
answer_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian if callback provided
if save_answer_callback and answer_text:
# Generate document ID as URN matching query-time provenance format
answer_doc_id = f"urn:trustgraph:answer:{session_id}"
try:
await save_answer_callback(answer_doc_id, answer_text)
if self.verbose:
logger.debug(f"Saved answer to librarian: {answer_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
answer_doc_id = None # Fall back to inline content
# Generate triples with document reference or inline content
ans_triples = answer_triples(
ans_uri, sel_uri,
answer_text="" if answer_doc_id else answer_text,
document_id=answer_doc_id,
)
await explain_callback(ans_triples, ans_uri)
if self.verbose:
logger.debug(f"Emitted explain for session {session_id}")
return resp

View file

@ -4,18 +4,28 @@ Simple RAG service, performs query using graph RAG an LLM.
Input is query, output is response.
"""
import asyncio
import base64
import logging
import uuid
from ... schema import GraphRagQuery, GraphRagResponse, Error
from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "graph-rag"
default_concurrency = 1
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor):
@ -28,6 +38,7 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
explainability_collection = params.get("explainability_collection", "explainability")
super(Processor, self).__init__(
**params | {
@ -37,6 +48,7 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"explainability_collection": explainability_collection,
}
)
@ -44,6 +56,7 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.explainability_collection = explainability_collection
# CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access
@ -93,10 +106,163 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
ProducerSpec(
name = "explainability",
schema = Triples,
)
)
# Librarian client for storing answer content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_librarian_requests = {}
logger.info("Graph RAG service initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
else:
logger.warning(f"Received unexpected librarian response: {request_id}")
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
user: User ID
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or "GraphRAG Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_librarian_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_request(self, msg, consumer, flow):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.info(f"Handling input {id}...")
# Track explainability refs for end_of_session signaling
explainability_refs_emitted = []
# Real-time explainability callback - emits triples and IDs as they're generated
async def send_explainability(triples, explain_id):
# Send triples to explainability queue
await flow("explainability").send(Triples(
metadata=Metadata(
id=explain_id,
metadata=[],
user=v.user,
collection=self.explainability_collection,
),
triples=triples,
))
# Send explain ID and collection to response queue
await flow("response").send(
GraphRagResponse(
message_type="explain",
explain_id=explain_id,
explain_collection=self.explainability_collection,
),
properties={"id": id}
)
explainability_refs_emitted.append(explain_id)
# CRITICAL SECURITY: Create new GraphRag instance per request
# This ensures proper isolation between users and collections
# Flow clients are request-scoped and must not be shared
@ -108,13 +274,6 @@ class Processor(FlowProcessor):
verbose=True,
)
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.info(f"Handling input {id}...")
if v.entity_limit:
entity_limit = v.entity_limit
else:
@ -135,6 +294,15 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
doc_id=doc_id,
user=v.user,
content=answer_text,
title=f"GraphRAG Answer: {v.query[:50]}...",
)
# Check if streaming is requested
if v.streaming:
# Define async callback for streaming chunks
@ -142,6 +310,7 @@ class Processor(FlowProcessor):
async def send_chunk(chunk, end_of_stream):
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response=chunk,
end_of_stream=end_of_stream,
error=None
@ -149,34 +318,50 @@ class Processor(FlowProcessor):
properties={"id": id}
)
# Query with streaming enabled
# All chunks (including final one with end_of_stream=True) are sent via callback
await rag.query(
# Query with streaming and real-time explain
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
streaming = True,
chunk_callback = send_chunk,
explain_callback = send_explainability,
save_answer_callback = save_answer,
)
else:
# Non-streaming path (existing behavior)
# Non-streaming path with real-time explain
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
explain_callback = send_explainability,
save_answer_callback = save_answer,
)
# Send chunk with response
await flow("response").send(
GraphRagResponse(
response = response,
end_of_stream = True,
error = None
message_type="chunk",
response=response,
end_of_stream=True,
error=None,
),
properties = {"id": id}
properties={"id": id}
)
# Send final message to close session
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response="",
end_of_session=True,
),
properties={"id": id}
)
logger.info("Request processing complete")
except Exception as e:
@ -185,22 +370,18 @@ class Processor(FlowProcessor):
logger.debug("Sending error response...")
# Send error response with end_of_stream flag if streaming was requested
error_response = GraphRagResponse(
response = None,
error = Error(
type = "graph-rag-error",
message = str(e),
),
)
# If streaming was requested, indicate stream end
if v.streaming:
error_response.end_of_stream = True
# Send error response and close session
await flow("response").send(
error_response,
properties = {"id": id}
GraphRagResponse(
message_type="chunk",
error=Error(
type="graph-rag-error",
message=str(e),
),
end_of_stream=True,
end_of_session=True,
),
properties={"id": id}
)
@staticmethod
@ -243,6 +424,12 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)'
)
parser.add_argument(
'--explainability-collection',
default='explainability',
help=f'Collection for storing explainability triples (default: explainability)'
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -443,13 +443,22 @@ class McpServer:
gen = manager.request("graph-rag", request_data, flow_id)
text_chunks = []
async for response in gen:
# Handle new message format with message_type
message_type = response.get("message_type", "chunk")
# Extract vectors from response
text = response.get("response", "")
break
return GraphRagResponse(response=text)
# Only collect text from chunk messages
if message_type == "chunk":
chunk_text = response.get("response", "")
if chunk_text:
text_chunks.append(chunk_text)
# Check if session is complete
if response.get("end_of_session"):
break
return GraphRagResponse(response="".join(text_chunks))
async def agent(
self,