mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-10 15:25:14 +02:00
GraphRAG Query-Time Explainability (#677)
Implements full explainability pipeline for GraphRAG queries, enabling
traceability from answers back to source documents.
Renamed throughout for clarity:
- provenance_callback → explain_callback
- provenance_id → explain_id
- provenance_collection → explain_collection
- message_type "provenance" → "explain"
- Queue name "provenance" → "explainability"
GraphRAG queries now emit explainability events as they execute:
1. Session - query text and timestamp
2. Retrieval - edges retrieved from subgraph
3. Selection - selected edges with LLM reasoning (JSONL with id + reasoning)
4. Answer - reference to synthesized response
Events are streamed via explain_callback while query() is running, enabling a real-time UX.
- Answers are stored in the librarian service rather than inline in the graph, because they are too large
- Document ID as URN: urn:trustgraph:answer:{session_id}
- Graph stores tg:document reference (IRI) to librarian document
- Added librarian producer/consumer to graph-rag service
- get_labelgraph() now returns (labeled_edges, uri_map)
- uri_map maps edge_id(label_s, label_p, label_o) → (uri_s, uri_p, uri_o)
- Explainability data stores original URIs, not labels
- Enables tracing edges back to reifying statements via tg:reifies
- Added serialize_triple() to query service (matches storage format)
- get_term_value() now handles TRIPLE type terms
- Enables querying by quoted triple in object position:
?stmt tg:reifies <<s p o>>
- The invoke_graph_rag CLI displays real-time explainability events during a query
- Resolves rdfs:label for edge components (s, p, o)
- Traces source chain via prov:wasDerivedFrom to root document
- Output: "Source: Chunk 1 → Page 2 → Document Title"
- Label caching to avoid repeated queries
GraphRagResponse:
- explain_id: str | None
- explain_collection: str | None
- message_type: str ("chunk" or "explain")
- end_of_session: bool
trustgraph-base/trustgraph/provenance/:
- namespaces.py - Added TG_DOCUMENT predicate
- triples.py - answer_triples() supports document_id reference
- uris.py - Added edge_selection_uri()
trustgraph-base/trustgraph/schema/services/retrieval.py:
- GraphRagResponse with explain_id, explain_collection, end_of_session
trustgraph-flow/trustgraph/retrieval/graph_rag/:
- graph_rag.py - URI preservation, streaming answer accumulation
- rag.py - Librarian integration, real-time explain emission
trustgraph-flow/trustgraph/query/triples/cassandra/service.py:
- Quoted triple serialization for query matching
trustgraph-cli/trustgraph/cli/invoke_graph_rag.py:
- Full explainability display with label resolution and source tracing
This commit is contained in:
parent
d2d71f859d
commit
7a6197d8c3
24 changed files with 2001 additions and 323 deletions
|
|
@ -1,13 +1,27 @@
|
|||
type: object
|
||||
description: Graph RAG response
|
||||
description: Graph RAG response message
|
||||
properties:
|
||||
message_type:
|
||||
type: string
|
||||
description: Type of message - "chunk" for LLM response chunks, "provenance" for provenance announcements
|
||||
enum: [chunk, provenance]
|
||||
example: chunk
|
||||
response:
|
||||
type: string
|
||||
description: Generated response based on retrieved knowledge graph
|
||||
description: Generated response text (for chunk messages)
|
||||
example: Quantum physics and computer science intersect in quantum computing...
|
||||
end-of-stream:
|
||||
provenance_id:
|
||||
type: string
|
||||
description: Provenance node URI (for provenance messages)
|
||||
example: urn:trustgraph:session:abc123
|
||||
end_of_stream:
|
||||
type: boolean
|
||||
description: Indicates streaming is complete (streaming mode)
|
||||
description: Indicates LLM response stream is complete
|
||||
default: false
|
||||
example: true
|
||||
end_of_session:
|
||||
type: boolean
|
||||
description: Indicates entire session is complete (all messages sent)
|
||||
default: false
|
||||
example: true
|
||||
error:
|
||||
|
|
|
|||
|
|
@ -17,16 +17,18 @@ from trustgraph.messaging import TranslatorRegistry
|
|||
class TestRAGTranslatorCompletionFlags:
|
||||
"""Contract tests for RAG response translator completion flags"""
|
||||
|
||||
def test_graph_rag_translator_is_final_with_end_of_stream_true(self):
|
||||
def test_graph_rag_translator_is_final_with_end_of_session_true(self):
|
||||
"""
|
||||
Test that GraphRagResponseTranslator returns is_final=True
|
||||
when end_of_stream=True.
|
||||
when end_of_session=True.
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("graph-rag")
|
||||
response = GraphRagResponse(
|
||||
response="A small domesticated mammal.",
|
||||
message_type="chunk",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
|
|
@ -34,20 +36,23 @@ class TestRAGTranslatorCompletionFlags:
|
|||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when end_of_stream=True"
|
||||
assert is_final is True, "is_final must be True when end_of_session=True"
|
||||
assert response_dict["response"] == "A small domesticated mammal."
|
||||
assert response_dict["end_of_stream"] is True
|
||||
assert response_dict["end_of_session"] is True
|
||||
assert response_dict["message_type"] == "chunk"
|
||||
|
||||
def test_graph_rag_translator_is_final_with_end_of_stream_false(self):
|
||||
def test_graph_rag_translator_is_final_with_end_of_session_false(self):
|
||||
"""
|
||||
Test that GraphRagResponseTranslator returns is_final=False
|
||||
when end_of_stream=False.
|
||||
when end_of_session=False (even if end_of_stream=True).
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("graph-rag")
|
||||
response = GraphRagResponse(
|
||||
response="Chunk 1",
|
||||
message_type="chunk",
|
||||
end_of_stream=False,
|
||||
end_of_session=False,
|
||||
error=None
|
||||
)
|
||||
|
||||
|
|
@ -55,9 +60,55 @@ class TestRAGTranslatorCompletionFlags:
|
|||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "is_final must be False when end_of_stream=False"
|
||||
assert is_final is False, "is_final must be False when end_of_session=False"
|
||||
assert response_dict["response"] == "Chunk 1"
|
||||
assert response_dict["end_of_stream"] is False
|
||||
assert response_dict["end_of_session"] is False
|
||||
|
||||
def test_graph_rag_translator_provenance_message(self):
|
||||
"""
|
||||
Test that GraphRagResponseTranslator handles provenance messages.
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("graph-rag")
|
||||
response = GraphRagResponse(
|
||||
response="",
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:session:abc123",
|
||||
end_of_stream=False,
|
||||
end_of_session=False,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False
|
||||
assert response_dict["message_type"] == "explain"
|
||||
assert response_dict["explain_id"] == "urn:trustgraph:session:abc123"
|
||||
|
||||
def test_graph_rag_translator_end_of_stream_not_final(self):
|
||||
"""
|
||||
Test that end_of_stream=True alone does NOT make is_final=True.
|
||||
The session continues with provenance messages after LLM stream completes.
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("graph-rag")
|
||||
response = GraphRagResponse(
|
||||
response="Final chunk",
|
||||
message_type="chunk",
|
||||
end_of_stream=True,
|
||||
end_of_session=False, # Session continues with provenance
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
|
||||
assert response_dict["end_of_stream"] is True
|
||||
assert response_dict["end_of_session"] is False
|
||||
|
||||
def test_document_rag_translator_is_final_with_end_of_stream_true(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -83,13 +83,25 @@ class TestGraphRagIntegration:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
"""Mock prompt client that generates realistic responses for two-step process"""
|
||||
client = AsyncMock()
|
||||
client.kg_prompt.return_value = (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
|
||||
# Mock responses for the two-step process:
|
||||
# 1. kg-edge-selection returns JSONL with edge IDs
|
||||
# 2. kg-synthesis returns the final answer
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Return empty selection (no edges selected) - valid JSONL
|
||||
return ""
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
return ""
|
||||
|
||||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -108,7 +120,7 @@ class TestGraphRagIntegration:
|
|||
async def test_graph_rag_end_to_end_flow(self, graph_rag, mock_embeddings_client,
|
||||
mock_graph_embeddings_client, mock_triples_client,
|
||||
mock_prompt_client):
|
||||
"""Test complete GraphRAG pipeline from query to response"""
|
||||
"""Test complete GraphRAG pipeline from query to response with real-time provenance"""
|
||||
# Arrange
|
||||
query = "What is machine learning?"
|
||||
user = "test_user"
|
||||
|
|
@ -116,13 +128,20 @@ class TestGraphRagIntegration:
|
|||
entity_limit = 50
|
||||
triple_limit = 30
|
||||
|
||||
# Collect provenance events
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
# Act
|
||||
result = await graph_rag.query(
|
||||
response = await graph_rag.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit
|
||||
triple_limit=triple_limit,
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Assert - Verify service coordination
|
||||
|
|
@ -141,16 +160,19 @@ class TestGraphRagIntegration:
|
|||
# 3. Should query triples to build knowledge subgraph
|
||||
assert mock_triples_client.query_stream.call_count > 0
|
||||
|
||||
# 4. Should call prompt with knowledge graph
|
||||
mock_prompt_client.kg_prompt.assert_called_once()
|
||||
call_args = mock_prompt_client.kg_prompt.call_args
|
||||
assert call_args.args[0] == query # First arg is query
|
||||
assert isinstance(call_args.args[1], list) # Second arg is kg (list of triples)
|
||||
# 4. Should call prompt twice (edge selection + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 2
|
||||
|
||||
# Verify final response
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "machine learning" in result.lower()
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
assert "machine learning" in response.lower()
|
||||
|
||||
# Verify provenance was emitted in real-time (4 events: session, retrieval, selection, answer)
|
||||
assert len(provenance_events) == 4
|
||||
for triples, prov_id in provenance_events:
|
||||
assert isinstance(triples, list)
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_with_different_limits(self, graph_rag, mock_embeddings_client,
|
||||
|
|
@ -206,19 +228,25 @@ class TestGraphRagIntegration:
|
|||
mock_graph_embeddings_client.query.return_value = [] # No entities found
|
||||
mock_triples_client.query_stream.return_value = [] # No triples found
|
||||
|
||||
# Collect provenance
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
# Act
|
||||
result = await graph_rag.query(
|
||||
response = await graph_rag.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
collection="test_collection",
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should still call prompt client with empty knowledge graph
|
||||
mock_prompt_client.kg_prompt.assert_called_once()
|
||||
call_args = mock_prompt_client.kg_prompt.call_args
|
||||
assert isinstance(call_args.args[1], list) # kg should be a list
|
||||
assert result is not None
|
||||
# Should still call prompt client (twice: edge selection + synthesis)
|
||||
assert response is not None
|
||||
# Provenance should still be emitted (4 events)
|
||||
assert len(provenance_events) == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):
|
||||
|
|
|
|||
|
|
@ -53,30 +53,34 @@ class TestGraphRagStreaming:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
|
||||
"""Mock prompt client with streaming support"""
|
||||
"""Mock prompt client with streaming support for two-stage GraphRAG"""
|
||||
client = AsyncMock()
|
||||
|
||||
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
# Both modes return the same text
|
||||
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
||||
# Full synthesis text
|
||||
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
chunks = []
|
||||
async for chunk in mock_streaming_llm_response():
|
||||
chunks.append(chunk)
|
||||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "kg-edge-selection":
|
||||
# Edge selection returns JSONL with IDs - simulate selecting first edge
|
||||
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
chunks = []
|
||||
async for chunk in mock_streaming_llm_response():
|
||||
chunks.append(chunk)
|
||||
|
||||
# Send all chunks with end_of_stream=False except the last
|
||||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
# Send all chunks with end_of_stream=False except the last
|
||||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return full_text
|
||||
else:
|
||||
return full_text
|
||||
return ""
|
||||
|
||||
client.kg_prompt.side_effect = kg_prompt_side_effect
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -93,18 +97,25 @@ class TestGraphRagStreaming:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_basic(self, graph_rag_streaming, streaming_chunk_collector):
|
||||
"""Test basic GraphRAG streaming functionality"""
|
||||
"""Test basic GraphRAG streaming functionality with real-time provenance"""
|
||||
# Arrange
|
||||
query = "What is machine learning?"
|
||||
collector = streaming_chunk_collector()
|
||||
|
||||
# Act
|
||||
result = await graph_rag_streaming.query(
|
||||
# Collect provenance events
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
# Act - query() returns response, provenance via callback
|
||||
response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collector.collect
|
||||
chunk_callback=collector.collect,
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
|
@ -116,10 +127,15 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Verify full response matches concatenated chunks
|
||||
full_from_chunks = collector.get_full_text()
|
||||
assert result == full_from_chunks
|
||||
assert response == full_from_chunks
|
||||
|
||||
# Verify content is reasonable
|
||||
assert "machine" in result.lower() or "learning" in result.lower()
|
||||
assert "machine" in response.lower() or "learning" in response.lower()
|
||||
|
||||
# Verify provenance was emitted in real-time (4 events)
|
||||
assert len(provenance_events) == 4
|
||||
for triples, prov_id in provenance_events:
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_vs_non_streaming(self, graph_rag_streaming):
|
||||
|
|
@ -130,7 +146,7 @@ class TestGraphRagStreaming:
|
|||
collection = "test_collection"
|
||||
|
||||
# Act - Non-streaming
|
||||
non_streaming_result = await graph_rag_streaming.query(
|
||||
non_streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
@ -143,7 +159,7 @@ class TestGraphRagStreaming:
|
|||
async def collect(chunk, end_of_stream):
|
||||
streaming_chunks.append(chunk)
|
||||
|
||||
streaming_result = await graph_rag_streaming.query(
|
||||
streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
@ -152,9 +168,9 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_result == non_streaming_result
|
||||
assert streaming_response == non_streaming_response
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_result
|
||||
assert "".join(streaming_chunks) == streaming_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
|
||||
|
|
@ -163,7 +179,7 @@ class TestGraphRagStreaming:
|
|||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await graph_rag_streaming.query(
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
|
|
@ -173,7 +189,7 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Assert
|
||||
assert callback.call_count > 0
|
||||
assert result is not None
|
||||
assert response is not None
|
||||
|
||||
# Verify all callback invocations had string arguments
|
||||
for call in callback.call_args_list:
|
||||
|
|
@ -183,7 +199,7 @@ class TestGraphRagStreaming:
|
|||
async def test_graph_rag_streaming_without_callback(self, graph_rag_streaming):
|
||||
"""Test streaming parameter without callback (should fall back to non-streaming)"""
|
||||
# Arrange & Act
|
||||
result = await graph_rag_streaming.query(
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
|
|
@ -192,8 +208,8 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should complete without error
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
|
||||
|
|
@ -204,7 +220,7 @@ class TestGraphRagStreaming:
|
|||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await graph_rag_streaming.query(
|
||||
response = await graph_rag_streaming.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
|
|
@ -213,7 +229,7 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should still produce streamed response
|
||||
assert result is not None
|
||||
assert response is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -44,18 +44,23 @@ class TestGraphRagStreamingProtocol:
|
|||
"""Mock prompt client that simulates realistic streaming with end_of_stream flags"""
|
||||
client = AsyncMock()
|
||||
|
||||
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
await chunk_callback("The", False)
|
||||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
else:
|
||||
return "The answer is here."
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Edge selection returns empty (no edges selected)
|
||||
return ""
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
await chunk_callback("The", False)
|
||||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
else:
|
||||
return "The answer is here."
|
||||
return ""
|
||||
|
||||
client.kg_prompt.side_effect = kg_prompt_side_effect
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -327,20 +332,24 @@ class TestStreamingProtocolEdgeCases:
|
|||
# Arrange
|
||||
client = AsyncMock()
|
||||
|
||||
async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return ""
|
||||
else:
|
||||
return "textmore"
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
return ""
|
||||
else:
|
||||
return "textmore"
|
||||
return ""
|
||||
|
||||
client.kg_prompt.side_effect = kg_prompt_with_empties
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
rag = GraphRag(
|
||||
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])),
|
||||
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
|
||||
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
prompt_client=client,
|
||||
|
|
|
|||
|
|
@ -547,21 +547,21 @@ class TestServiceHelperFunctions:
|
|||
"""Test cases for helper functions in service.py"""
|
||||
|
||||
def test_create_term_with_uri_otype(self):
|
||||
"""Test create_term creates IRI Term for otype='u'"""
|
||||
"""Test create_term creates IRI Term for term_type='u'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import IRI
|
||||
|
||||
term = create_term('http://example.org/Alice', otype='u')
|
||||
term = create_term('http://example.org/Alice', term_type='u')
|
||||
|
||||
assert term.type == IRI
|
||||
assert term.iri == 'http://example.org/Alice'
|
||||
|
||||
def test_create_term_with_literal_otype(self):
|
||||
"""Test create_term creates LITERAL Term for otype='l'"""
|
||||
"""Test create_term creates LITERAL Term for term_type='l'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import LITERAL
|
||||
|
||||
term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en')
|
||||
term = create_term('Alice Smith', term_type='l', datatype='xsd:string', language='en')
|
||||
|
||||
assert term.type == LITERAL
|
||||
assert term.value == 'Alice Smith'
|
||||
|
|
@ -569,7 +569,7 @@ class TestServiceHelperFunctions:
|
|||
assert term.language == 'en'
|
||||
|
||||
def test_create_term_with_triple_otype(self):
|
||||
"""Test create_term creates TRIPLE Term for otype='t' with valid JSON"""
|
||||
"""Test create_term creates TRIPLE Term for term_type='t' with valid JSON"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import TRIPLE, IRI
|
||||
import json
|
||||
|
|
@ -581,7 +581,7 @@ class TestServiceHelperFunctions:
|
|||
"o": {"type": "i", "iri": "http://example.org/Bob"},
|
||||
})
|
||||
|
||||
term = create_term(triple_json, otype='t')
|
||||
term = create_term(triple_json, term_type='t')
|
||||
|
||||
assert term.type == TRIPLE
|
||||
assert term.triple is not None
|
||||
|
|
|
|||
|
|
@ -96,20 +96,21 @@ class TestGraphRagResponseTranslator:
|
|||
assert is_final is False
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
# Test final chunk with empty content
|
||||
# Test final message with end_of_session=True
|
||||
final_response = GraphRagResponse(
|
||||
response="",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(final_response)
|
||||
|
||||
# Assert
|
||||
# Assert - is_final is based on end_of_session, not end_of_stream
|
||||
assert is_final is True
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
assert result["end_of_session"] is True
|
||||
|
||||
|
||||
class TestDocumentRagResponseTranslator:
|
||||
|
|
|
|||
|
|
@ -118,8 +118,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
# Verify result contains the queried triple
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
def test_processor_initialization_with_defaults(self):
|
||||
|
|
@ -182,8 +182,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -219,8 +219,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -256,8 +256,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -293,8 +293,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -331,8 +331,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'all_subject'
|
||||
assert result[0].p.value == 'all_predicate'
|
||||
assert result[0].s.iri == 'all_subject'
|
||||
assert result[0].p.iri == 'all_predicate'
|
||||
assert result[0].o.value == 'all_object'
|
||||
|
||||
def test_add_args_method(self):
|
||||
|
|
@ -637,8 +637,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -678,8 +678,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -802,7 +802,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
# Verify all results were returned
|
||||
assert len(result) == 5
|
||||
for i, triple in enumerate(result):
|
||||
assert triple.s.value == f'subject_{i}' # Mock returns literal values
|
||||
assert triple.s.iri == f'subject_{i}' # Mock returns literal values
|
||||
assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
|
||||
assert triple.p.type == IRI
|
||||
assert triple.o.iri == 'http://example.com/Person' # URIs use .iri
|
||||
|
|
|
|||
|
|
@ -540,41 +540,68 @@ class TestQuery:
|
|||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
# Call get_labelgraph
|
||||
result = await query.get_labelgraph("test query")
|
||||
|
||||
labeled_edges, uri_map = await query.get_labelgraph("test query")
|
||||
|
||||
# Verify get_subgraph was called
|
||||
query.get_subgraph.assert_called_once_with("test query")
|
||||
|
||||
|
||||
# Verify label triples are filtered out
|
||||
assert len(result) == 2 # Label triple should be excluded
|
||||
|
||||
assert len(labeled_edges) == 2 # Label triple should be excluded
|
||||
|
||||
# Verify maybe_label was called for non-label triples
|
||||
expected_calls = [
|
||||
(("entity1",), {}), (("predicate1",), {}), (("object1",), {}),
|
||||
(("entity3",), {}), (("predicate3",), {}), (("object3",), {})
|
||||
]
|
||||
assert query.maybe_label.call_count == 6
|
||||
|
||||
|
||||
# Verify result contains human-readable labels
|
||||
expected_result = [
|
||||
expected_edges = [
|
||||
("Human Entity One", "Human Predicate One", "Human Object One"),
|
||||
("Human Entity Three", "Human Predicate Three", "Human Object Three")
|
||||
]
|
||||
assert result == expected_result
|
||||
assert labeled_edges == expected_edges
|
||||
|
||||
# Verify uri_map maps labeled edges back to original URIs
|
||||
assert len(uri_map) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_query_method(self):
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline"""
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline with real-time provenance"""
|
||||
import json
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
|
||||
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
|
||||
# Mock prompt client response
|
||||
|
||||
# Mock prompt client responses for two-step process
|
||||
expected_response = "This is the RAG response"
|
||||
mock_prompt_client.kg_prompt.return_value = expected_response
|
||||
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
|
||||
# Compute the edge ID for the test edge
|
||||
test_edge_id = edge_id("Subject", "Predicate", "Object")
|
||||
|
||||
# Create uri_map for the test edge (maps labeled edge ID to original URIs)
|
||||
test_uri_map = {
|
||||
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
}
|
||||
|
||||
# Mock edge selection response (JSONL format)
|
||||
edge_selection_response = json.dumps({"id": test_edge_id, "reasoning": "relevant"})
|
||||
|
||||
# Configure prompt mock to return different responses based on prompt name
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return edge_selection_response
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return expected_response
|
||||
return ""
|
||||
|
||||
mock_prompt_client.prompt = mock_prompt
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -583,39 +610,55 @@ class TestQuery:
|
|||
triples_client=mock_triples_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Mock the Query class behavior by patching get_labelgraph
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
|
||||
|
||||
# We need to patch the Query class's get_labelgraph method
|
||||
original_query_init = Query.__init__
|
||||
original_get_labelgraph = Query.get_labelgraph
|
||||
|
||||
|
||||
def mock_query_init(self, *args, **kwargs):
|
||||
original_query_init(self, *args, **kwargs)
|
||||
|
||||
|
||||
async def mock_get_labelgraph(self, query_text):
|
||||
return test_labelgraph
|
||||
|
||||
return test_labelgraph, test_uri_map
|
||||
|
||||
Query.__init__ = mock_query_init
|
||||
Query.get_labelgraph = mock_get_labelgraph
|
||||
|
||||
|
||||
# Collect provenance emitted via callback
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
try:
|
||||
# Call GraphRag.query
|
||||
result = await graph_rag.query(
|
||||
# Call GraphRag.query with provenance callback
|
||||
response = await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=25,
|
||||
triple_limit=15
|
||||
triple_limit=15,
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Verify prompt client was called with knowledge graph and query
|
||||
mock_prompt_client.kg_prompt.assert_called_once_with("test query", test_labelgraph)
|
||||
|
||||
# Verify result
|
||||
assert result == expected_response
|
||||
|
||||
|
||||
# Verify response text
|
||||
assert response == expected_response
|
||||
|
||||
# Verify provenance was emitted incrementally (4 events: session, retrieval, selection, answer)
|
||||
assert len(provenance_events) == 4
|
||||
|
||||
# Verify each event has triples and a URN
|
||||
for triples, prov_id in provenance_events:
|
||||
assert isinstance(triples, list)
|
||||
assert len(triples) > 0
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
# Verify order: session, retrieval, selection, answer
|
||||
assert "session" in provenance_events[0][1]
|
||||
assert "retrieval" in provenance_events[1][1]
|
||||
assert "selection" in provenance_events[2][1]
|
||||
assert "answer" in provenance_events[3][1]
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
Query.__init__ = original_query_init
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Unit tests for GraphRAG service non-streaming mode.
|
||||
Tests that end_of_stream flag is correctly set in non-streaming responses.
|
||||
Unit tests for GraphRAG service message format.
|
||||
Tests the new message protocol with message_type, explain_id, and end_of_session.
|
||||
Real-time explainability emission via callback.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -11,16 +12,14 @@ from trustgraph.schema import GraphRagQuery, GraphRagResponse
|
|||
|
||||
|
||||
class TestGraphRagService:
|
||||
"""Test GraphRAG service non-streaming behavior"""
|
||||
"""Test GraphRAG service message protocol"""
|
||||
|
||||
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_graph_rag_class):
|
||||
async def test_non_streaming_sends_chunk_then_provenance_messages(self, mock_graph_rag_class):
|
||||
"""
|
||||
Test that non-streaming mode sets end_of_stream=True in response.
|
||||
|
||||
This is a regression test for the bug where non-streaming responses
|
||||
didn't set end_of_stream, causing clients to hang waiting for more data.
|
||||
Test that non-streaming mode sends real-time provenance messages
|
||||
followed by chunk message with response.
|
||||
"""
|
||||
# Setup processor
|
||||
processor = Processor(
|
||||
|
|
@ -32,10 +31,22 @@ class TestGraphRagService:
|
|||
max_path_length=2
|
||||
)
|
||||
|
||||
# Setup mock GraphRag instance
|
||||
# Setup mock GraphRag instance that calls explain_callback
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_graph_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = "A small domesticated mammal."
|
||||
|
||||
# Mock query() to call the explain_callback with each provenance event
|
||||
async def mock_query(**kwargs):
|
||||
explain_callback = kwargs.get('explain_callback')
|
||||
if explain_callback:
|
||||
# Simulate real-time provenance emission
|
||||
await explain_callback([], "urn:trustgraph:session:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:selection:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:answer:test")
|
||||
return "A small domesticated mammal."
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
# Setup message with non-streaming request
|
||||
msg = MagicMock()
|
||||
|
|
@ -47,7 +58,7 @@ class TestGraphRagService:
|
|||
triple_limit=30,
|
||||
max_subgraph_size=150,
|
||||
max_path_length=2,
|
||||
streaming=False # Non-streaming mode
|
||||
streaming=False
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
||||
|
|
@ -55,30 +66,48 @@ class TestGraphRagService:
|
|||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
|
||||
# Mock flow to return AsyncMock for clients and response producer
|
||||
mock_producer = AsyncMock()
|
||||
mock_response_producer = AsyncMock()
|
||||
mock_provenance_producer = AsyncMock()
|
||||
def flow_router(service_name):
|
||||
if service_name == "response":
|
||||
return mock_producer
|
||||
return AsyncMock() # embeddings, graph-embeddings, triples, prompt clients
|
||||
return mock_response_producer
|
||||
elif service_name == "explainability":
|
||||
return mock_provenance_producer
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: response was sent with end_of_stream=True
|
||||
mock_producer.send.assert_called_once()
|
||||
sent_response = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_response, GraphRagResponse)
|
||||
assert sent_response.response == "A small domesticated mammal."
|
||||
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
|
||||
assert sent_response.error is None
|
||||
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
|
||||
assert mock_response_producer.send.call_count == 6
|
||||
|
||||
# First 4 messages are explain (emitted in real-time during query)
|
||||
for i in range(4):
|
||||
prov_msg = mock_response_producer.send.call_args_list[i][0][0]
|
||||
assert prov_msg.message_type == "explain"
|
||||
assert prov_msg.explain_id is not None
|
||||
|
||||
# 5th message is chunk with response
|
||||
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "A small domesticated mammal."
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# 6th message is empty chunk with end_of_session=True
|
||||
close_msg = mock_response_producer.send.call_args_list[5][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
|
||||
# Verify provenance triples were sent to provenance queue
|
||||
assert mock_provenance_producer.send.call_count == 4
|
||||
|
||||
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_response_in_non_streaming_mode(self, mock_graph_rag_class):
|
||||
async def test_error_response_closes_session(self, mock_graph_rag_class):
|
||||
"""
|
||||
Test that error responses in non-streaming mode set end_of_stream=True.
|
||||
Test that error responses set end_of_session=True.
|
||||
"""
|
||||
# Setup processor
|
||||
processor = Processor(
|
||||
|
|
@ -105,7 +134,7 @@ class TestGraphRagService:
|
|||
triple_limit=30,
|
||||
max_subgraph_size=150,
|
||||
max_path_length=2,
|
||||
streaming=False # Non-streaming mode
|
||||
streaming=False
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
||||
|
|
@ -113,22 +142,93 @@ class TestGraphRagService:
|
|||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
mock_response_producer = AsyncMock()
|
||||
mock_provenance_producer = AsyncMock()
|
||||
def flow_router(service_name):
|
||||
if service_name == "response":
|
||||
return mock_producer
|
||||
return mock_response_producer
|
||||
elif service_name == "explainability":
|
||||
return mock_provenance_producer
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: error response was sent without end_of_stream (not streaming mode)
|
||||
mock_producer.send.assert_called_once()
|
||||
sent_response = mock_producer.send.call_args[0][0]
|
||||
# Verify: error response was sent with session closed
|
||||
mock_response_producer.send.assert_called_once()
|
||||
sent_response = mock_response_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_response, GraphRagResponse)
|
||||
assert sent_response.response is None
|
||||
assert sent_response.message_type == "chunk"
|
||||
assert sent_response.error is not None
|
||||
assert sent_response.error.message == "Test error"
|
||||
# Note: error responses in non-streaming mode don't set end_of_stream
|
||||
# because streaming was never started
|
||||
assert sent_response.end_of_stream is True
|
||||
assert sent_response.end_of_session is True
|
||||
|
||||
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
@pytest.mark.asyncio
async def test_no_provenance_sends_empty_chunk_to_close(self, mock_graph_rag_class):
    """
    Test that when no provenance callback is invoked, an empty chunk closes the session.

    The mocked query() returns its text without ever calling
    explain_callback, so the response producer is expected to receive
    exactly two messages: the answer chunk (end_of_stream=True) and then
    an empty chunk carrying end_of_session=True.
    """
    # Setup processor
    processor = Processor(
        taskgroup=MagicMock(),
        id="test-processor",
        entity_limit=50,
        triple_limit=30,
        max_subgraph_size=150,
        max_path_length=2
    )

    # Setup mock GraphRag instance that doesn't call provenance callback
    mock_rag_instance = AsyncMock()
    mock_graph_rag_class.return_value = mock_rag_instance

    async def mock_query(**kwargs):
        # Don't call explain_callback
        return "Response text"

    mock_rag_instance.query.side_effect = mock_query

    # Setup message
    msg = MagicMock()
    msg.value.return_value = GraphRagQuery(
        query="Test query",
        user="trustgraph",
        collection="default",
        streaming=False
    )
    msg.properties.return_value = {"id": "test-id"}

    # Setup flow mock
    consumer = MagicMock()
    flow = MagicMock()

    mock_response_producer = AsyncMock()
    mock_provenance_producer = AsyncMock()

    def flow_router(service_name):
        # Route the two named producers; anything else gets a throwaway mock
        if service_name == "response":
            return mock_response_producer
        elif service_name == "explainability":
            return mock_provenance_producer
        return AsyncMock()

    flow.side_effect = flow_router

    # Execute
    await processor.on_request(msg, consumer, flow)

    # Verify: 2 messages (chunk + empty chunk to close)
    assert mock_response_producer.send.call_count == 2

    # First is the response chunk
    chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
    assert chunk_msg.message_type == "chunk"
    assert chunk_msg.response == "Response text"
    assert chunk_msg.end_of_stream is True

    # Second is empty chunk to close session
    close_msg = mock_response_producer.send.call_args_list[1][0][0]
    assert close_msg.message_type == "chunk"
    assert close_msg.response == ""
    assert close_msg.end_of_session is True
|
||||
|
|
|
|||
|
|
@ -110,15 +110,25 @@ class AsyncSocketClient:
|
|||
|
||||
# Parse different chunk types
|
||||
chunk = self._parse_chunk(resp)
|
||||
yield chunk
|
||||
if chunk is not None: # Skip provenance messages in streaming
|
||||
yield chunk
|
||||
|
||||
# Check if this is the final chunk
|
||||
if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"):
|
||||
# Check if this is the final message
|
||||
# end_of_session indicates entire session is complete (including provenance)
|
||||
# end_of_dialog is for agent dialogs
|
||||
# complete is from the gateway envelope
|
||||
if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"):
|
||||
break
|
||||
|
||||
def _parse_chunk(self, resp: Dict[str, Any]):
|
||||
"""Parse response chunk into appropriate type"""
|
||||
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
|
||||
chunk_type = resp.get("chunk_type")
|
||||
message_type = resp.get("message_type")
|
||||
|
||||
# Handle new GraphRAG message format with message_type
|
||||
if message_type == "provenance":
|
||||
# Provenance messages are not yielded to user - they're metadata
|
||||
return None
|
||||
|
||||
if chunk_type == "thought":
|
||||
return AgentThought(
|
||||
|
|
@ -143,7 +153,7 @@ class AsyncSocketClient:
|
|||
end_of_message=resp.get("end_of_message", False)
|
||||
)
|
||||
else:
|
||||
# RAG-style chunk (or generic chunk)
|
||||
# RAG-style chunk (or generic chunk with message_type="chunk")
|
||||
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
|
||||
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
|
||||
return RAGChunk(
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import websockets
|
|||
from typing import Optional, Dict, Any, Iterator, Union, List
|
||||
from threading import Lock
|
||||
|
||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk
|
||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent
|
||||
from . exceptions import ProtocolException, raise_from_error_dict
|
||||
|
||||
|
||||
|
|
@ -310,15 +310,28 @@ class SocketClient:
|
|||
|
||||
# Parse different chunk types
|
||||
chunk = self._parse_chunk(resp)
|
||||
yield chunk
|
||||
if chunk is not None: # Skip provenance messages in streaming
|
||||
yield chunk
|
||||
|
||||
# Check if this is the final chunk
|
||||
if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"):
|
||||
# Check if this is the final message
|
||||
# end_of_session indicates entire session is complete (including provenance)
|
||||
# end_of_dialog is for agent dialogs
|
||||
# complete is from the gateway envelope
|
||||
if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"):
|
||||
break
|
||||
|
||||
def _parse_chunk(self, resp: Dict[str, Any]) -> StreamingChunk:
|
||||
"""Parse response chunk into appropriate type"""
|
||||
def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]:
|
||||
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
|
||||
chunk_type = resp.get("chunk_type")
|
||||
message_type = resp.get("message_type")
|
||||
|
||||
# Handle new GraphRAG message format with message_type
|
||||
if message_type == "provenance":
|
||||
if include_provenance:
|
||||
# Return provenance event for explainability
|
||||
return ProvenanceEvent(provenance_id=resp.get("provenance_id", ""))
|
||||
# Provenance messages are not yielded to user - they're metadata
|
||||
return None
|
||||
|
||||
if chunk_type == "thought":
|
||||
return AgentThought(
|
||||
|
|
@ -360,7 +373,7 @@ class SocketClient:
|
|||
end_of_dialog=resp.get("end_of_dialog", False)
|
||||
)
|
||||
else:
|
||||
# RAG-style chunk (or generic chunk)
|
||||
# RAG-style chunk (or generic chunk with message_type="chunk")
|
||||
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
|
||||
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
|
||||
return RAGChunk(
|
||||
|
|
|
|||
|
|
@ -202,3 +202,29 @@ class RAGChunk(StreamingChunk):
|
|||
chunk_type: str = "rag"
|
||||
end_of_stream: bool = False
|
||||
error: Optional[Dict[str, str]] = None
|
||||
|
||||
@dataclasses.dataclass
class ProvenanceEvent:
    """
    Provenance event for explainability.

    Emitted during GraphRAG queries when explainable mode is enabled.
    Each event represents a provenance node created during query processing.

    Attributes:
        provenance_id: URI of the provenance node (e.g., urn:trustgraph:session:abc123)
        event_type: Type of provenance event (session, retrieval, selection, answer).
            Filled in from provenance_id after construction when one of the
            known type names can be found in it; otherwise left as given.
    """
    provenance_id: str
    event_type: str = ""  # Derived from provenance_id (session, retrieval, selection, answer)

    def __post_init__(self):
        # Known event types, in the priority order used by the fallback below
        known_types = ("session", "retrieval", "selection", "answer")

        # Prefer an exact ':'-delimited segment match so that a type word
        # appearing inside the trailing identifier (e.g. "...:answer:session-42")
        # cannot shadow the real type segment.
        for segment in self.provenance_id.split(":"):
            if segment in known_types:
                self.event_type = segment
                return

        # Fallback for identifiers that are not ':'-segmented: plain
        # substring search, first hit in priority order wins.
        for name in known_types:
            if name in self.provenance_id:
                self.event_type = name
                return
|
||||
|
|
|
|||
|
|
@ -90,13 +90,31 @@ class GraphRagResponseTranslator(MessageTranslator):
|
|||
def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
# Include response content (even if empty string)
|
||||
# Include message_type
|
||||
message_type = getattr(obj, "message_type", "")
|
||||
if message_type:
|
||||
result["message_type"] = message_type
|
||||
|
||||
# Include response content for chunk messages
|
||||
if obj.response is not None:
|
||||
result["response"] = obj.response
|
||||
|
||||
# Include end_of_stream flag
|
||||
# Include explain_id for explain messages
|
||||
explain_id = getattr(obj, "explain_id", None)
|
||||
if explain_id:
|
||||
result["explain_id"] = explain_id
|
||||
|
||||
# Include explain_collection for explain messages
|
||||
explain_collection = getattr(obj, "explain_collection", None)
|
||||
if explain_collection:
|
||||
result["explain_collection"] = explain_collection
|
||||
|
||||
# Include end_of_stream flag (LLM stream complete)
|
||||
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||
|
||||
# Include end_of_session flag (entire session complete)
|
||||
result["end_of_session"] = getattr(obj, "end_of_session", False)
|
||||
|
||||
# Always include error if present
|
||||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "type": obj.error.type}
|
||||
|
|
@ -105,5 +123,6 @@ class GraphRagResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
    """Returns (response_dict, is_final)

    is_final is read from the object's end_of_session attribute (False when
    the attribute is absent); response_dict is whatever from_pulsar()
    produces for the same object.
    """
    # Session is complete when end_of_session is True
    is_final = getattr(obj, 'end_of_session', False)
    return self.from_pulsar(obj), is_final
|
||||
|
|
@ -40,6 +40,11 @@ from . uris import (
|
|||
activity_uri,
|
||||
statement_uri,
|
||||
agent_uri,
|
||||
# Query-time provenance URIs
|
||||
query_session_uri,
|
||||
retrieval_uri,
|
||||
selection_uri,
|
||||
answer_uri,
|
||||
)
|
||||
|
||||
# Namespace constants
|
||||
|
|
@ -58,6 +63,8 @@ from . namespaces import (
|
|||
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
|
||||
TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL,
|
||||
TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH,
|
||||
# Query-time provenance predicates
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_CONTENT,
|
||||
)
|
||||
|
||||
# Triple builders
|
||||
|
|
@ -65,6 +72,11 @@ from . triples import (
|
|||
document_triples,
|
||||
derived_entity_triples,
|
||||
triple_provenance_triples,
|
||||
# Query-time provenance triple builders
|
||||
query_session_triples,
|
||||
retrieval_triples,
|
||||
selection_triples,
|
||||
answer_triples,
|
||||
)
|
||||
|
||||
# Vocabulary bootstrap
|
||||
|
|
@ -86,6 +98,11 @@ __all__ = [
|
|||
"activity_uri",
|
||||
"statement_uri",
|
||||
"agent_uri",
|
||||
# Query-time provenance URIs
|
||||
"query_session_uri",
|
||||
"retrieval_uri",
|
||||
"selection_uri",
|
||||
"answer_uri",
|
||||
# Namespaces
|
||||
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
|
||||
"PROV_WAS_DERIVED_FROM", "PROV_WAS_GENERATED_BY",
|
||||
|
|
@ -97,10 +114,17 @@ __all__ = [
|
|||
"TG_CHUNK_SIZE", "TG_CHUNK_OVERLAP", "TG_COMPONENT_VERSION",
|
||||
"TG_LLM_MODEL", "TG_ONTOLOGY", "TG_EMBEDDING_MODEL",
|
||||
"TG_SOURCE_TEXT", "TG_SOURCE_CHAR_OFFSET", "TG_SOURCE_CHAR_LENGTH",
|
||||
# Query-time provenance predicates
|
||||
"TG_QUERY", "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_CONTENT",
|
||||
# Triple builders
|
||||
"document_triples",
|
||||
"derived_entity_triples",
|
||||
"triple_provenance_triples",
|
||||
# Query-time provenance triple builders
|
||||
"query_session_triples",
|
||||
"retrieval_triples",
|
||||
"selection_triples",
|
||||
"answer_triples",
|
||||
# Vocabulary
|
||||
"get_vocabulary_triples",
|
||||
"PROV_CLASS_LABELS",
|
||||
|
|
|
|||
|
|
@ -58,3 +58,12 @@ TG_EMBEDDING_MODEL = TG + "embeddingModel"
|
|||
TG_SOURCE_TEXT = TG + "sourceText"
TG_SOURCE_CHAR_OFFSET = TG + "sourceCharOffset"
TG_SOURCE_CHAR_LENGTH = TG + "sourceCharLength"

# Query-time provenance predicates
TG_QUERY = TG + "query"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_DOCUMENT = TG + "document"  # Reference to document in librarian
|
||||
|
|
|
|||
|
|
@ -17,9 +17,12 @@ from . namespaces import (
|
|||
TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH,
|
||||
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
|
||||
TG_LLM_MODEL, TG_ONTOLOGY, TG_REIFIES,
|
||||
# Query-time provenance predicates
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_CONTENT,
|
||||
TG_DOCUMENT,
|
||||
)
|
||||
|
||||
from . uris import activity_uri, agent_uri
|
||||
from . uris import activity_uri, agent_uri, edge_selection_uri
|
||||
|
||||
|
||||
def _iri(uri: str) -> Term:
|
||||
|
|
@ -252,3 +255,177 @@ def triple_provenance_triples(
|
|||
triples.append(_triple(act_uri, TG_ONTOLOGY, _iri(ontology_uri)))
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
# Query-time provenance triple builders
|
||||
|
||||
def query_session_triples(
    session_uri: str,
    query: str,
    timestamp: Optional[str] = None,
) -> List[Triple]:
    """
    Build triples for a query session activity.

    Creates:
    - Activity declaration for the query session
    - Query text and timestamp

    Args:
        session_uri: URI of the session (from query_session_uri)
        query: The user's query text
        timestamp: ISO timestamp (defaults to now)

    Returns:
        List of Triple objects
    """
    if timestamp is None:
        # datetime.utcnow() is deprecated since Python 3.12. Take an aware
        # UTC "now" and drop tzinfo so the text keeps the same
        # "<naive ISO-8601>Z" shape the previous implementation produced.
        from datetime import timezone
        timestamp = (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
        )

    return [
        _triple(session_uri, RDF_TYPE, _iri(PROV_ACTIVITY)),
        _triple(session_uri, RDFS_LABEL, _literal("GraphRAG query session")),
        _triple(session_uri, PROV_STARTED_AT_TIME, _literal(timestamp)),
        _triple(session_uri, TG_QUERY, _literal(query)),
    ]
|
||||
|
||||
|
||||
def retrieval_triples(
    retrieval_uri: str,
    session_uri: str,
    edge_count: int,
) -> List[Triple]:
    """
    Build triples for a retrieval entity (all edges retrieved from subgraph).

    Creates:
    - Entity declaration for retrieval
    - wasGeneratedBy link to session
    - Edge count metadata

    Args:
        retrieval_uri: URI of the retrieval entity (from retrieval_uri)
        session_uri: URI of the parent session
        edge_count: Number of edges retrieved

    Returns:
        List of Triple objects
    """
    return [
        _triple(retrieval_uri, RDF_TYPE, _iri(PROV_ENTITY)),
        _triple(retrieval_uri, RDFS_LABEL, _literal("Retrieved edges")),
        _triple(retrieval_uri, PROV_WAS_GENERATED_BY, _iri(session_uri)),
        _triple(retrieval_uri, TG_EDGE_COUNT, _literal(edge_count)),
    ]
|
||||
|
||||
|
||||
def _quoted_triple(s: str, p: str, o: str) -> Term:
    """Create a quoted triple term (RDF-star) from string values.

    All three components, including the object, are wrapped with _iri().
    """
    return Term(
        type=TRIPLE,
        triple=Triple(s=_iri(s), p=_iri(p), o=_iri(o))
    )
|
||||
|
||||
|
||||
def selection_triples(
    selection_uri: str,
    retrieval_uri: str,
    selected_edges_with_reasoning: List[dict],
    session_id: str = "",
) -> List[Triple]:
    """
    Build triples for a selection entity (selected edges with reasoning).

    Creates:
    - Entity declaration for selection
    - wasDerivedFrom link to retrieval
    - For each selected edge: an edge selection entity with quoted triple and reasoning

    Structure:
        <selection> tg:selectedEdge <edge_sel_1> .
        <edge_sel_1> tg:edge << <s> <p> <o> >> .
        <edge_sel_1> tg:reasoning "reason" .

    Args:
        selection_uri: URI of the selection entity (from selection_uri)
        retrieval_uri: URI of the parent retrieval entity
        selected_edges_with_reasoning: List of dicts with 'edge' (s,p,o tuple) and 'reasoning'
        session_id: Session UUID for generating edge selection URIs

    Returns:
        List of Triple objects
    """
    triples = [
        _triple(selection_uri, RDF_TYPE, _iri(PROV_ENTITY)),
        _triple(selection_uri, RDFS_LABEL, _literal("Selected edges")),
        _triple(selection_uri, PROV_WAS_DERIVED_FROM, _iri(retrieval_uri)),
    ]

    # Add each selected edge with its reasoning via intermediate entity.
    # The index comes from the position in the input list, so entries that
    # are skipped for lacking an edge still consume an index.
    for idx, edge_info in enumerate(selected_edges_with_reasoning):
        edge = edge_info.get("edge")
        reasoning = edge_info.get("reasoning", "")

        if not edge:
            continue

        s, p, o = edge

        # Create intermediate entity for this edge selection
        edge_sel_uri = edge_selection_uri(session_id, idx)

        # Link selection to edge selection entity
        triples.append(
            _triple(selection_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri))
        )

        # Attach quoted triple to edge selection entity
        quoted = _quoted_triple(s, p, o)
        triples.append(
            Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted)
        )

        # Attach reasoning to edge selection entity
        if reasoning:
            triples.append(
                _triple(edge_sel_uri, TG_REASONING, _literal(reasoning))
            )

    return triples
|
||||
|
||||
|
||||
def answer_triples(
    answer_uri: str,
    selection_uri: str,
    answer_text: str = "",
    document_id: Optional[str] = None,
) -> List[Triple]:
    """
    Build triples for an answer entity (final synthesis text).

    Creates:
    - Entity declaration for answer
    - wasDerivedFrom link to selection
    - Either document reference (if document_id provided) or inline content

    Args:
        answer_uri: URI of the answer entity (from answer_uri)
        selection_uri: URI of the parent selection entity
        answer_text: The synthesized answer text (used if no document_id)
        document_id: Optional librarian document ID (preferred over inline content)

    Returns:
        List of Triple objects
    """
    triples = [
        _triple(answer_uri, RDF_TYPE, _iri(PROV_ENTITY)),
        _triple(answer_uri, RDFS_LABEL, _literal("GraphRAG answer")),
        _triple(answer_uri, PROV_WAS_DERIVED_FROM, _iri(selection_uri)),
    ]

    if document_id:
        # Store reference to document in librarian (as IRI)
        triples.append(_triple(answer_uri, TG_DOCUMENT, _iri(document_id)))
    elif answer_text:
        # Fallback: store inline content
        triples.append(_triple(answer_uri, TG_CONTENT, _literal(answer_text)))

    return triples
|
||||
|
|
|
|||
|
|
@ -60,3 +60,75 @@ def statement_uri(stmt_id: str = None) -> str:
|
|||
def agent_uri(component_name: str) -> str:
    """Generate URI for a TrustGraph component agent."""
    return f"{TRUSTGRAPH_BASE}/agent/{_encode_id(component_name)}"
|
||||
|
||||
|
||||
# Query-time provenance URIs
|
||||
# These URIs use the urn:trustgraph: namespace to distinguish query-time
|
||||
# provenance from extraction-time provenance (which uses https://trustgraph.ai/)
|
||||
|
||||
def query_session_uri(session_id: str = None) -> str:
    """
    Generate URI for a query session activity.

    Args:
        session_id: Optional UUID string. Auto-generates if not provided.

    Returns:
        URN in format: urn:trustgraph:session:{uuid}
    """
    if session_id is None:
        session_id = str(uuid.uuid4())
    return f"urn:trustgraph:session:{session_id}"
|
||||
|
||||
|
||||
def retrieval_uri(session_id: str) -> str:
    """
    Generate URI for a retrieval entity (edges retrieved from subgraph).

    Args:
        session_id: The session UUID (same as query_session_uri).

    Returns:
        URN in format: urn:trustgraph:prov:retrieval:{uuid}
    """
    return f"urn:trustgraph:prov:retrieval:{session_id}"
|
||||
|
||||
|
||||
def selection_uri(session_id: str) -> str:
    """
    Build the URN for a selection entity (the selected edges together
    with their reasoning).

    Args:
        session_id: The session UUID (the same value used for
            query_session_uri).

    Returns:
        URN in format: urn:trustgraph:prov:selection:{uuid}
    """
    return "urn:trustgraph:prov:selection:{}".format(session_id)
|
||||
|
||||
|
||||
def answer_uri(session_id: str) -> str:
    """
    Build the URN for an answer entity (the final synthesis text).

    Args:
        session_id: The session UUID (the same value used for
            query_session_uri).

    Returns:
        URN in format: urn:trustgraph:prov:answer:{uuid}
    """
    return "urn:trustgraph:prov:answer:{}".format(session_id)
|
||||
|
||||
|
||||
def edge_selection_uri(session_id: str, edge_index: int) -> str:
    """
    Build the URN for one edge selection item (the node that links a
    selected edge to its reasoning).

    Args:
        session_id: The session UUID.
        edge_index: Position of this edge within the selection (0-based).

    Returns:
        URN in format: urn:trustgraph:prov:edge:{uuid}:{index}
    """
    return "urn:trustgraph:prov:edge:{}:{}".format(session_id, edge_index)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,11 @@ class GraphRagQuery:
|
|||
class GraphRagResponse:
    """
    Response message of the GraphRAG service.

    A response either carries a piece of the answer text or announces an
    explainability event; message_type tells the two apart.
    """
    error: Error | None = None
    response: str = ""
    # Declared once only: the previous duplicate declaration of this field
    # was redundant.
    end_of_stream: bool = False  # LLM response stream complete
    explain_id: str | None = None  # Single explain URI (announced as created)
    explain_collection: str | None = None  # Collection where explain was stored
    message_type: str = ""  # "chunk" or "explain"
    end_of_session: bool = False  # Entire session complete
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,11 @@ Uses the GraphRAG service to answer a question
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import websockets
|
||||
import asyncio
|
||||
from trustgraph.api import Api
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
|
|
@ -15,11 +19,609 @@ default_triple_limit = 30
|
|||
default_max_subgraph_size = 150
|
||||
default_max_path_length = 2
|
||||
|
||||
# Provenance predicates
# Full predicate IRIs compared against the "p" position of triples returned
# by the triples service while rendering explainability events.

# TrustGraph namespace prefix and the predicates built on it
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"  # query text shown for session events
TG_EDGE_COUNT = TG + "edgeCount"  # edge count shown for retrieval events
TG_SELECTED_EDGE = TG + "selectedEdge"  # selection -> edge selection node
TG_EDGE = TG + "edge"  # edge selection node -> quoted edge triple
TG_REASONING = TG + "reasoning"  # edge selection node -> reasoning text
TG_CONTENT = TG + "content"  # inline answer content
TG_REIFIES = TG + "reifies"  # statement -> quoted triple it reifies

# W3C PROV namespace prefix and the predicates built on it
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"

# rdfs:label, used to turn IRIs into readable names
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
|
||||
def _get_event_type(prov_id):
    """
    Classify an explainability URI as a query-time event type.

    URNs of the form urn:trustgraph:session:{id} and
    urn:trustgraph:prov:{type}:{id} are classified by their type segment,
    so a session id that happens to contain one of the keywords (for
    example "urn:trustgraph:prov:answer:my-session-1") is no longer
    misread. Any other identifier falls back to the keyword substring
    search.

    Args:
        prov_id: The explain/provenance URI string.

    Returns:
        One of "session", "retrieval", "selection", "answer", or
        "provenance" when nothing matches.
    """
    known_types = ("session", "retrieval", "selection", "answer")
    urn_prefix = "urn:trustgraph:"

    if prov_id.startswith(urn_prefix):
        segments = prov_id[len(urn_prefix):].split(":")
        # Entity URNs carry an extra "prov" segment before the type
        if segments and segments[0] == "prov":
            segments = segments[1:]
        if segments and segments[0] in known_types:
            return segments[0]

    # Legacy behaviour: first keyword found anywhere in the identifier
    for event_type in known_types:
        if event_type in prov_id:
            return event_type

    return "provenance"
|
||||
|
||||
|
||||
def _format_provenance_details(event_type, triples):
    """
    Turn the triples of one explainability event into display lines.

    Args:
        event_type: "session", "retrieval", "selection" or "answer";
            any other value yields no lines.
        triples: Iterable of (s, p, o) tuples describing the event node.

    Returns:
        List of formatted strings (possibly empty).
    """
    if event_type == "session":
        # Query text and start time, in the order the triples arrive
        details = []
        for _subject, predicate, value in triples:
            if predicate == TG_QUERY:
                details.append(f" Query: {value}")
            elif predicate == PROV_STARTED_AT_TIME:
                details.append(f" Time: {value}")
        return details

    if event_type == "retrieval":
        # Number of edges pulled from the subgraph
        return [
            f" Edges retrieved: {value}"
            for _subject, predicate, value in triples
            if predicate == TG_EDGE_COUNT
        ]

    if event_type == "selection":
        # Only the count is reported here; the per-edge details are
        # fetched separately by the caller
        selected = [
            value
            for _subject, predicate, value in triples
            if predicate == TG_SELECTED_EDGE
        ]
        if selected:
            return [f" Selected {len(selected)} edge(s)"]
        return []

    if event_type == "answer":
        # Length only - the answer text itself has already been streamed
        return [
            f" Answer length: {len(value)} chars"
            for _subject, predicate, value in triples
            if predicate == TG_CONTENT
        ]

    return []
|
||||
|
||||
|
||||
async def _query_triples_once(ws_url, flow_id, prov_id, user, collection, debug=False):
    """
    Fetch the triples whose subject is prov_id (single attempt).

    Opens a websocket to ws_url, sends one "triples" request and collects
    replies until one is flagged complete, an error reply arrives, or the
    connection ends. Any exception is swallowed (reported on stderr only
    when debug is set) and whatever was collected so far is returned.

    Returns:
        List of (s, p, o) tuples. s and p are strings; o is a string, or
        a dict with "s", "p", "o" keys when the object is a quoted triple.
    """

    def term_text(term):
        # Compact wire format: "i" holds an IRI, "v" holds a literal value
        return term.get("i", term.get("v", ""))

    def object_value(term):
        if term.get("t") != "t":
            return term_text(term)
        # Quoted triple - unpack the nested s, p, o
        quoted = term.get("tr", {})
        return {
            "s": quoted.get("s", {}).get("i", ""),
            "p": quoted.get("p", {}).get("i", ""),
            "o": term_text(quoted.get("o", {})),
        }

    payload = {
        "id": "triples-request",
        "service": "triples",
        "flow": flow_id,
        "request": {
            "s": {"t": "i", "i": prov_id},
            "user": user,
            "collection": collection,
            "limit": 100
        }
    }

    if debug:
        print(f" [debug] querying triples for s={prov_id}", file=sys.stderr)

    collected = []
    try:
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
            await websocket.send(json.dumps(payload))

            async for raw_message in websocket:
                message = json.loads(raw_message)

                if debug:
                    print(f" [debug] response: {json.dumps(message)[:200]}", file=sys.stderr)

                # Ignore traffic that does not answer our request
                if message.get("id") != "triples-request":
                    continue

                if "error" in message:
                    if debug:
                        print(f" [debug] error: {message['error']}", file=sys.stderr)
                    break

                if "response" not in message:
                    continue

                # Response format: {"response": [triples...]}
                body = message["response"]
                for item in body.get("response", []):
                    subject = term_text(item.get("s", {}))
                    predicate = term_text(item.get("p", {}))
                    obj = object_value(item.get("o", {}))
                    collected.append((subject, predicate, obj))

                if body.get("complete") or message.get("complete"):
                    break
    except Exception as e:
        if debug:
            print(f" [debug] exception: {e}", file=sys.stderr)

    if debug:
        print(f" [debug] got {len(collected)} triples", file=sys.stderr)

    return collected
|
||||
|
||||
|
||||
async def _query_triples(ws_url, flow_id, prov_id, user, collection, max_retries=5, retry_delay=0.2, debug=False):
    """
    Fetch the triples for a provenance node, retrying while none are found.

    The explain event can be announced before its triples are stored, so
    an empty result is retried up to max_retries attempts in total, with
    retry_delay seconds between attempts.

    Returns:
        The first non-empty triple list, or [] if every attempt was empty.
    """
    for attempt in range(max_retries):
        if attempt > 0:
            # Previous attempt came back empty - pause before trying again
            if debug:
                print(f" [debug] retry {attempt}/{max_retries}...", file=sys.stderr)
            await asyncio.sleep(retry_delay)

        found = await _query_triples_once(ws_url, flow_id, prov_id, user, collection, debug)
        if found:
            return found

    return []
|
||||
|
||||
|
||||
async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, collection, debug=False):
    """
    Query for provenance of an edge (s, p, o) in the knowledge graph.

    Finds statements that reify the edge via tg:reifies, then follows
    prov:wasDerivedFrom from each statement to find source documents.

    Args:
        ws_url: Websocket URL of the API socket.
        flow_id: Flow identifier placed in each request.
        edge_s, edge_p: Subject and predicate IRIs of the edge.
        edge_o: Object of the edge; sent as an IRI when _is_iri() accepts
            it, otherwise as a literal.
        user, collection: Scope of the triples queries.
        debug: When True, progress and errors are printed to stderr.

    Returns list of source URIs (chunks, pages, documents). Failures are
    swallowed, so the list may be empty or partial.
    """
    # Classify the object with the same check used elsewhere in this file.
    # A bare startswith("http") test would send literals such as "httpd"
    # as IRIs and would raise on non-string objects.
    if _is_iri(edge_o):
        object_term = {"t": "i", "i": edge_o}
    else:
        object_term = {"t": "l", "v": edge_o}

    # Query for statements that reify this edge: ?stmt tg:reifies <<s p o>>
    request = {
        "id": "edge-prov-request",
        "service": "triples",
        "flow": flow_id,
        "request": {
            "p": {"t": "i", "i": TG_REIFIES},
            "o": {
                "t": "t",  # Quoted triple type
                "tr": {
                    "s": {"t": "i", "i": edge_s},
                    "p": {"t": "i", "i": edge_p},
                    "o": object_term,
                }
            },
            "user": user,
            "collection": collection,
            "limit": 10
        }
    }

    if debug:
        print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr)

    stmt_uris = []
    try:
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
            await websocket.send(json.dumps(request))

            async for raw_message in websocket:
                response = json.loads(raw_message)

                if response.get("id") != "edge-prov-request":
                    continue

                if "error" in response:
                    if debug:
                        print(f" [debug] error: {response['error']}", file=sys.stderr)
                    break

                if "response" in response:
                    resp = response["response"]
                    # The subject of each match is a reifying statement
                    for t in resp.get("response", []):
                        s = t.get("s", {}).get("i", "")
                        if s:
                            stmt_uris.append(s)

                    if resp.get("complete") or response.get("complete"):
                        break
    except Exception as e:
        # Best effort: provenance display must not abort the query
        if debug:
            print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr)

    if debug:
        print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr)

    # For each statement, query wasDerivedFrom to find sources
    sources = []
    for stmt_uri in stmt_uris:
        # Query: stmt_uri prov:wasDerivedFrom ?source
        request = {
            "id": "derived-from-request",
            "service": "triples",
            "flow": flow_id,
            "request": {
                "s": {"t": "i", "i": stmt_uri},
                "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
                "user": user,
                "collection": collection,
                "limit": 10
            }
        }

        try:
            async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
                await websocket.send(json.dumps(request))

                async for raw_message in websocket:
                    response = json.loads(raw_message)

                    if response.get("id") != "derived-from-request":
                        continue

                    if "error" in response:
                        # Logged like the reifies query above
                        if debug:
                            print(f" [debug] error: {response['error']}", file=sys.stderr)
                        break

                    if "response" in response:
                        resp = response["response"]
                        for t in resp.get("response", []):
                            o = t.get("o", {}).get("i", "")
                            if o:
                                sources.append(o)

                        if resp.get("complete") or response.get("complete"):
                            break
        except Exception as e:
            if debug:
                print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr)

    if debug:
        print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr)

    return sources
|
||||
|
||||
|
||||
async def _query_derived_from(ws_url, flow_id, uri, user, collection, debug=False):
    """
    Look up the prov:wasDerivedFrom parent of a URI.

    Sends a single triples query (limit 1) for
    "<uri> prov:wasDerivedFrom ?parent" and returns the IRI of the first
    match. Exceptions are swallowed (printed to stderr when debug is set).

    Returns:
        The parent IRI, or None when there is no parent, the first match
        has no IRI object, or the query failed.
    """
    payload = {
        "id": "parent-request",
        "service": "triples",
        "flow": flow_id,
        "request": {
            "s": {"t": "i", "i": uri},
            "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
            "user": user,
            "collection": collection,
            "limit": 1
        }
    }

    try:
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
            await websocket.send(json.dumps(payload))

            async for raw_message in websocket:
                message = json.loads(raw_message)

                # Skip anything that is not the reply to this request
                if message.get("id") != "parent-request":
                    continue

                if "error" in message:
                    break

                if "response" not in message:
                    continue

                body = message["response"]
                matches = body.get("response", [])
                if matches:
                    # First match wins; only IRI objects count as a parent
                    return matches[0].get("o", {}).get("i", None)

                if body.get("complete") or message.get("complete"):
                    break
    except Exception as e:
        if debug:
            print(f" [debug] exception querying parent: {e}", file=sys.stderr)

    return None
|
||||
|
||||
|
||||
async def _trace_provenance_chain(ws_url, flow_id, source_uri, user, collection, label_cache, debug=False, max_depth=10):
    """
    Trace the full provenance chain from a source URI up to the root document.

    Starting at source_uri, repeatedly resolves the label of the current
    node and follows prov:wasDerivedFrom to its parent. The walk stops
    when there is no parent, when a node would be visited a second time
    (self-loops and longer cycles alike), or after max_depth nodes.

    Args:
        ws_url, flow_id, user, collection: Passed through to the label
            and parent queries.
        source_uri: URI to start from; a falsy value yields an empty chain.
        label_cache: Dict shared with _query_label() to avoid repeat lookups.
        debug: When True, a detected cycle is reported on stderr.
        max_depth: Maximum number of nodes in the chain (default 10).

    Returns a list of (uri, label) tuples from leaf to root.
    """
    chain = []
    visited = set()
    current = source_uri

    for _ in range(max_depth):
        if not current:
            break

        if current in visited:
            # Cycle in wasDerivedFrom links - stop instead of repeating
            # nodes until the depth cap is reached
            if debug:
                print(f" [debug] provenance cycle detected at {current}", file=sys.stderr)
            break
        visited.add(current)

        # Get label for current entity
        label = await _query_label(ws_url, flow_id, current, user, collection, label_cache, debug)
        chain.append((current, label))

        # Move to the parent (None ends the walk on the next iteration)
        current = await _query_derived_from(ws_url, flow_id, current, user, collection, debug)

    return chain
|
||||
|
||||
|
||||
def _format_provenance_chain(chain):
    """
    Render a provenance chain as a single human-readable string.

    Args:
        chain: [(uri, label), ...] ordered from leaf to root.

    Returns:
        The labels joined with " → " in chain order, or "" when the
        chain is empty or None.
    """
    if chain:
        # Only the labels are shown; the URIs are dropped
        return " → ".join(label for _uri, label in chain)
    return ""
|
||||
|
||||
|
||||
def _is_iri(value):
    """
    Tell whether a value looks like an IRI.

    Only strings starting with "http://", "https://" or "urn:" qualify;
    every non-string value is rejected.
    """
    return isinstance(value, str) and value.startswith(
        ("http://", "https://", "urn:")
    )
|
||||
|
||||
|
||||
async def _query_label(ws_url, flow_id, iri, user, collection, label_cache, debug=False):
    """
    Resolve the rdfs:label of an IRI through the triples service.

    Values that are not IRIs are returned unchanged without a query.
    Results - including the fallback to the IRI itself when no label is
    found or the query fails - are stored in label_cache, which is
    consulted before any query is sent.

    Returns:
        The label if found, otherwise the IRI.
    """
    if not _is_iri(iri):
        return iri

    # Serve repeated lookups from the cache
    if iri in label_cache:
        return label_cache[iri]

    payload = {
        "id": "label-request",
        "service": "triples",
        "flow": flow_id,
        "request": {
            "s": {"t": "i", "i": iri},
            "p": {"t": "i", "i": RDFS_LABEL},
            "user": user,
            "collection": collection,
            "limit": 1
        }
    }

    resolved = iri  # Falls back to the IRI when nothing better turns up
    try:
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
            await websocket.send(json.dumps(payload))

            async for raw_message in websocket:
                message = json.loads(raw_message)

                if message.get("id") != "label-request":
                    continue

                if "error" in message:
                    break

                if "response" not in message:
                    continue

                body = message["response"]
                matches = body.get("response", [])
                if matches:
                    # Prefer the literal value; accept an IRI object too
                    term = matches[0].get("o", {})
                    resolved = term.get("v", term.get("i", iri))

                if body.get("complete") or message.get("complete"):
                    break
    except Exception as e:
        if debug:
            print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr)

    # Remember the outcome for later lookups
    label_cache[iri] = resolved
    return resolved
|
||||
|
||||
|
||||
async def _resolve_edge_labels(ws_url, flow_id, edge_triple, user, collection, label_cache, debug=False):
    """
    Resolve display labels for the three components of an edge.

    Args:
        edge_triple: Dict with "s", "p" and "o" entries; a missing entry
            is treated as "?".
        label_cache: Dict shared with _query_label().

    Returns (s_label, p_label, o_label).
    """
    components = [edge_triple.get(key, "?") for key in ("s", "p", "o")]

    # Looked up one after another, in s, p, o order
    labels = []
    for component in components:
        labels.append(
            await _query_label(ws_url, flow_id, component, user, collection, label_cache, debug)
        )

    return tuple(labels)
|
||||
|
||||
|
||||
async def _question_explainable(
    url, flow_id, question, user, collection, entity_limit, triple_limit,
    max_subgraph_size, max_path_length, token=None, debug=False
):
    """Execute graph RAG with explainability - shows provenance events with details

    Sends one streaming "graph-rag" request over the API websocket and
    processes the replies as they arrive:

    - "explain" messages: the announced node's triples are fetched from the
      collection named in the event and summarised on stderr. For selection
      events each selected edge is additionally printed with resolved
      labels, its (truncated) reasoning and the source chain traced in the
      query collection.
    - "chunk" messages (or messages with no type): the answer text is
      written to stdout as it streams.

    Processing stops at the first error or when a reply carries
    end_of_session.

    Args:
        url: Base URL of the API; http(s) schemes are mapped to ws(s).
        flow_id: Flow identifier placed in the request.
        question: The query text.
        user, collection: Scope of the query and of edge provenance lookups.
        entity_limit, triple_limit, max_subgraph_size, max_path_length:
            Forwarded unchanged as request parameters.
        token: Optional token appended to the socket URL as a query
            parameter.
        debug: When True, diagnostic output is printed to stderr.
    """
    # Convert HTTP URL to WebSocket URL
    if url.startswith("http://"):
        ws_url = url.replace("http://", "ws://", 1)
    elif url.startswith("https://"):
        ws_url = url.replace("https://", "wss://", 1)
    else:
        # No recognised scheme: treat the value as host[:port]
        ws_url = f"ws://{url}"

    ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"
    if token:
        ws_url = f"{ws_url}?token={token}"

    # Cache for label lookups to avoid repeated queries
    label_cache = {}

    request = {
        "id": "cli-request",
        "service": "graph-rag",
        "flow": flow_id,
        "request": {
            "query": question,
            "user": user,
            "collection": collection,
            "entity-limit": entity_limit,
            "triple-limit": triple_limit,
            "max-subgraph-size": max_subgraph_size,
            "max-path-length": max_path_length,
            "streaming": True
        }
    }

    async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket:
        await websocket.send(json.dumps(request))

        async for raw_message in websocket:
            response = json.loads(raw_message)

            # Ignore messages that do not answer this request
            if response.get("id") != "cli-request":
                continue

            if "error" in response:
                print(f"\nError: {response['error']}", file=sys.stderr)
                break

            if "response" in response:
                resp = response["response"]

                # Check for errors in response
                if "error" in resp and resp["error"]:
                    err = resp["error"]
                    print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr)
                    break

                message_type = resp.get("message_type", "")

                if debug:
                    print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr)

                if message_type == "explain":
                    # Display explain event with details
                    explain_id = resp.get("explain_id", "")
                    explain_collection = resp.get("explain_collection", "explainability")
                    if explain_id:
                        event_type = _get_event_type(explain_id)
                        print(f"\n [{event_type}] {explain_id}", file=sys.stderr)

                        # Query triples for this explain node (using explain collection from event)
                        triples = await _query_triples(
                            ws_url, flow_id, explain_id, user, explain_collection, debug=debug
                        )

                        # Format and display details
                        details = _format_provenance_details(event_type, triples)
                        for line in details:
                            print(line, file=sys.stderr)

                        # For selection events, query each edge selection for details
                        if event_type == "selection":
                            for s, p, o in triples:
                                if debug:
                                    print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr)
                                # Only string objects can be followed as edge selection URIs
                                if p == TG_SELECTED_EDGE and isinstance(o, str):
                                    if debug:
                                        print(f" [debug] querying edge selection: {o}", file=sys.stderr)
                                    # Query the edge selection entity (using explain collection from event)
                                    edge_triples = await _query_triples(
                                        ws_url, flow_id, o, user, explain_collection, debug=debug
                                    )
                                    if debug:
                                        print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr)
                                    # Extract edge and reasoning
                                    edge_triple = None  # Store the actual triple for provenance lookup
                                    reasoning = None
                                    for es, ep, eo in edge_triples:
                                        if debug:
                                            print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr)
                                        if ep == TG_EDGE and isinstance(eo, dict):
                                            # eo is a quoted triple dict
                                            edge_triple = eo
                                        elif ep == TG_REASONING:
                                            reasoning = eo
                                    if edge_triple:
                                        # Resolve labels for edge components
                                        s_label, p_label, o_label = await _resolve_edge_labels(
                                            ws_url, flow_id, edge_triple, user, collection,
                                            label_cache, debug=debug
                                        )
                                        print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
                                    if reasoning:
                                        # Keep the display compact: at most 100 characters of reasoning
                                        r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
                                        print(f" Reason: {r_short}", file=sys.stderr)

                                    # Trace edge provenance in the user's collection (not explainability)
                                    if edge_triple:
                                        sources = await _query_edge_provenance(
                                            ws_url, flow_id,
                                            edge_triple.get("s", ""),
                                            edge_triple.get("p", ""),
                                            edge_triple.get("o", ""),
                                            user, collection,  # Use the query collection, not explainability
                                            debug=debug
                                        )
                                        if sources:
                                            for src in sources:
                                                # Trace full chain from source to root document
                                                chain = await _trace_provenance_chain(
                                                    ws_url, flow_id, src, user, collection,
                                                    label_cache, debug=debug
                                                )
                                                chain_str = _format_provenance_chain(chain)
                                                print(f" Source: {chain_str}", file=sys.stderr)

                elif message_type == "chunk" or not message_type:
                    # Display response chunk
                    chunk = resp.get("response", "")
                    if chunk:
                        print(chunk, end="", flush=True)

                # Check if session is complete
                if resp.get("end_of_session"):
                    break

    print()  # Final newline
|
||||
|
||||
|
||||
def question(
|
||||
url, flow_id, question, user, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, streaming=True, token=None
|
||||
max_subgraph_size, max_path_length, streaming=True, token=None,
|
||||
explainable=False, debug=False
|
||||
):
|
||||
|
||||
# Explainable mode uses direct websocket to capture provenance events
|
||||
if explainable:
|
||||
asyncio.run(_question_explainable(
|
||||
url=url,
|
||||
flow_id=flow_id,
|
||||
question=question,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length,
|
||||
token=token,
|
||||
debug=debug
|
||||
))
|
||||
return
|
||||
|
||||
# Create API client
|
||||
api = Api(url=url, token=token)
|
||||
|
||||
|
|
@ -138,6 +740,18 @@ def main():
|
|||
help='Disable streaming (use non-streaming mode)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--explainable',
|
||||
action='store_true',
|
||||
help='Show provenance events for explainability (implies streaming)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--debug',
|
||||
action='store_true',
|
||||
help='Show debug output for troubleshooting'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -154,6 +768,8 @@ def main():
|
|||
max_path_length=args.max_path_length,
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from .... direct.cassandra_kg import (
|
|||
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
|
||||
)
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Term, Triple, IRI, LITERAL, TRIPLE
|
||||
from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK
|
||||
from .... base import TriplesQueryService
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
|
|
@ -23,6 +23,36 @@ logger = logging.getLogger(__name__)
|
|||
default_ident = "triples-query"
|
||||
|
||||
|
||||
def serialize_triple(triple):
    """
    Serialize a Triple object to a JSON string for querying.

    The output must match the storage format: an object with "s", "p"
    and "o" keys, each holding a term dict that starts with "type"
    followed by the fields for that term type. A nested triple term is
    stored as the JSON string produced by a recursive call.

    Returns:
        The JSON string, or None when triple is None.
    """
    if triple is None:
        return None

    def encode_term(term):
        if term is None:
            return None

        kind = term.type
        encoded = {"type": kind}

        if kind == IRI:
            encoded["iri"] = term.iri
        elif kind == LITERAL:
            encoded["value"] = term.value
            # Datatype and language are only written when set
            if term.datatype:
                encoded["datatype"] = term.datatype
            if term.language:
                encoded["language"] = term.language
        elif kind == BLANK:
            encoded["id"] = term.id
        elif kind == TRIPLE:
            encoded["triple"] = serialize_triple(term.triple)

        return encoded

    # Key order s, p, o is part of the serialized form
    return json.dumps({
        position: encode_term(getattr(triple, position))
        for position in ("s", "p", "o")
    })
|
||||
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
|
|
@ -31,6 +61,9 @@ def get_term_value(term):
|
|||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
elif term.type == TRIPLE:
|
||||
# Serialize nested triple to JSON (must match storage format)
|
||||
return serialize_triple(term.triple)
|
||||
else:
|
||||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
|
@ -66,51 +99,50 @@ def deserialize_term(term_dict):
|
|||
return Term(type=LITERAL, value=str(term_dict))
|
||||
|
||||
|
||||
def create_term(value, otype=None, dtype=None, lang=None):
|
||||
def create_term(value, term_type=None, datatype=None, language=None):
|
||||
"""
|
||||
Create a Term from a string value, optionally using type metadata.
|
||||
|
||||
Args:
|
||||
value: The string value
|
||||
otype: Object type - 'u' (URI), 'l' (literal), 't' (triple)
|
||||
dtype: XSD datatype (for literals)
|
||||
lang: Language tag (for literals)
|
||||
term_type: 'u' (IRI), 'l' (literal), 't' (triple)
|
||||
datatype: XSD datatype for literals
|
||||
language: Language tag for literals
|
||||
|
||||
If otype is provided, uses it to determine Term type.
|
||||
Otherwise falls back to URL detection heuristic.
|
||||
If term_type is provided, uses it to determine Term type.
|
||||
Otherwise falls back to URL detection heuristic for object values.
|
||||
"""
|
||||
if otype is not None:
|
||||
if otype == 'u':
|
||||
return Term(type=IRI, iri=value)
|
||||
elif otype == 'l':
|
||||
return Term(
|
||||
type=LITERAL,
|
||||
value=value,
|
||||
datatype=dtype or "",
|
||||
language=lang or ""
|
||||
)
|
||||
elif otype == 't':
|
||||
# Triple/reification - parse JSON and create nested Triple
|
||||
try:
|
||||
triple_data = json.loads(value) if isinstance(value, str) else value
|
||||
if isinstance(triple_data, dict):
|
||||
return Term(
|
||||
type=TRIPLE,
|
||||
triple=Triple(
|
||||
s=deserialize_term(triple_data.get("s")),
|
||||
p=deserialize_term(triple_data.get("p")),
|
||||
o=deserialize_term(triple_data.get("o")),
|
||||
)
|
||||
if term_type == 'u':
|
||||
return Term(type=IRI, iri=value)
|
||||
elif term_type == 'l':
|
||||
return Term(
|
||||
type=LITERAL,
|
||||
value=value,
|
||||
datatype=datatype or "",
|
||||
language=language or ""
|
||||
)
|
||||
elif term_type == 't':
|
||||
# Triple/reification - parse JSON and create nested Triple
|
||||
try:
|
||||
triple_data = json.loads(value) if isinstance(value, str) else value
|
||||
if isinstance(triple_data, dict):
|
||||
return Term(
|
||||
type=TRIPLE,
|
||||
triple=Triple(
|
||||
s=deserialize_term(triple_data.get("s")),
|
||||
p=deserialize_term(triple_data.get("p")),
|
||||
o=deserialize_term(triple_data.get("o")),
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse triple JSON: {e}")
|
||||
# Fallback if parsing fails
|
||||
return Term(type=LITERAL, value=str(value))
|
||||
else:
|
||||
# Unknown otype, fall back to heuristic
|
||||
pass
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse triple JSON: {e}")
|
||||
# Fallback if parsing fails
|
||||
return Term(type=LITERAL, value=str(value))
|
||||
elif term_type is not None:
|
||||
# Unknown term_type, fall back to heuristic
|
||||
pass
|
||||
|
||||
# Heuristic fallback for backwards compatibility
|
||||
# Heuristic fallback for backwards compatibility (object values only)
|
||||
if value.startswith("http://") or value.startswith("https://"):
|
||||
return Term(type=IRI, iri=value)
|
||||
else:
|
||||
|
|
@ -176,13 +208,13 @@ class Processor(TriplesQueryService):
|
|||
o_val = get_term_value(query.o)
|
||||
g_val = query.g # Already a string or None
|
||||
|
||||
# Helper to extract object metadata from result row
|
||||
def get_o_metadata(t):
|
||||
"""Extract otype/dtype/lang from result row if available"""
|
||||
otype = getattr(t, 'otype', None)
|
||||
dtype = getattr(t, 'dtype', None)
|
||||
lang = getattr(t, 'lang', None)
|
||||
return otype, dtype, lang
|
||||
def get_object_metadata(row):
|
||||
"""Extract term type metadata from result row"""
|
||||
return (
|
||||
getattr(row, 'otype', None),
|
||||
getattr(row, 'dtype', None),
|
||||
getattr(row, 'lang', None),
|
||||
)
|
||||
|
||||
quads = []
|
||||
|
||||
|
|
@ -197,8 +229,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((s_val, p_val, o_val, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
|
||||
else:
|
||||
# SP specified
|
||||
resp = self.tg.get_sp(
|
||||
|
|
@ -207,8 +239,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((s_val, p_val, t.o, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((s_val, p_val, t.o, g, term_type, datatype, language))
|
||||
else:
|
||||
if o_val is not None:
|
||||
# SO specified
|
||||
|
|
@ -218,8 +250,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((s_val, t.p, o_val, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
|
||||
else:
|
||||
# S only
|
||||
resp = self.tg.get_s(
|
||||
|
|
@ -228,8 +260,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((s_val, t.p, t.o, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((s_val, t.p, t.o, g, term_type, datatype, language))
|
||||
else:
|
||||
if p_val is not None:
|
||||
if o_val is not None:
|
||||
|
|
@ -240,8 +272,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((t.s, p_val, o_val, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
|
||||
else:
|
||||
# P only
|
||||
resp = self.tg.get_p(
|
||||
|
|
@ -250,8 +282,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((t.s, p_val, t.o, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((t.s, p_val, t.o, g, term_type, datatype, language))
|
||||
else:
|
||||
if o_val is not None:
|
||||
# O only
|
||||
|
|
@ -261,8 +293,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((t.s, t.p, o_val, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
|
||||
else:
|
||||
# Nothing specified - get all
|
||||
resp = self.tg.get_all(
|
||||
|
|
@ -272,16 +304,17 @@ class Processor(TriplesQueryService):
|
|||
for t in resp:
|
||||
# Note: quads_by_collection uses 'd' for graph field
|
||||
g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((t.s, t.p, t.o, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((t.s, t.p, t.o, g, term_type, datatype, language))
|
||||
|
||||
# Convert to Triple objects (with g field)
|
||||
# Use otype/dtype/lang for proper Term reconstruction if available
|
||||
# s and p are always IRIs in RDF
|
||||
# Object uses term_type/datatype/language metadata from database
|
||||
triples = [
|
||||
Triple(
|
||||
s=create_term(q[0]),
|
||||
p=create_term(q[1]),
|
||||
o=create_term(q[2], otype=q[4], dtype=q[5], lang=q[6]),
|
||||
s=create_term(q[0], term_type='u'),
|
||||
p=create_term(q[1], term_type='u'),
|
||||
o=create_term(q[2], term_type=q[4], datatype=q[5], language=q[6]),
|
||||
g=q[3] if q[3] != DEFAULT_GRAPH else None
|
||||
)
|
||||
for q in quads
|
||||
|
|
@ -311,12 +344,13 @@ class Processor(TriplesQueryService):
|
|||
o_val = get_term_value(query.o)
|
||||
g_val = query.g
|
||||
|
||||
# Helper to extract object metadata from result row
|
||||
def get_o_metadata(t):
|
||||
otype = getattr(t, 'otype', None)
|
||||
dtype = getattr(t, 'dtype', None)
|
||||
lang = getattr(t, 'lang', None)
|
||||
return otype, dtype, lang
|
||||
def get_object_metadata(row):
|
||||
"""Extract term type metadata from result row"""
|
||||
return (
|
||||
getattr(row, 'otype', None),
|
||||
getattr(row, 'dtype', None),
|
||||
getattr(row, 'lang', None),
|
||||
)
|
||||
|
||||
# For streaming, we need to execute with fetch_size
|
||||
# Use the collection table for get_all queries (most common streaming case)
|
||||
|
|
@ -345,12 +379,13 @@ class Processor(TriplesQueryService):
|
|||
break
|
||||
|
||||
g = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(row)
|
||||
term_type, datatype, language = get_object_metadata(row)
|
||||
|
||||
# s and p are always IRIs in RDF
|
||||
triple = Triple(
|
||||
s=create_term(row.s),
|
||||
p=create_term(row.p),
|
||||
o=create_term(row.o, otype=otype, dtype=dtype, lang=lang),
|
||||
s=create_term(row.s, term_type='u'),
|
||||
p=create_term(row.p, term_type='u'),
|
||||
o=create_term(row.o, term_type=term_type, datatype=datatype, language=language),
|
||||
g=g if g != DEFAULT_GRAPH else None
|
||||
)
|
||||
batch.append(triple)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,27 @@
|
|||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
|
||||
from ... schema import IRI, LITERAL
|
||||
|
||||
# Provenance imports
|
||||
from trustgraph.provenance import (
|
||||
query_session_uri,
|
||||
retrieval_uri as make_retrieval_uri,
|
||||
selection_uri as make_selection_uri,
|
||||
answer_uri as make_answer_uri,
|
||||
query_session_triples,
|
||||
retrieval_triples,
|
||||
selection_triples,
|
||||
answer_triples,
|
||||
)
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -23,6 +39,12 @@ def term_to_string(term):
|
|||
# Fallback
|
||||
return term.iri or term.value or str(term)
|
||||
|
||||
|
||||
def edge_id(s, p, o):
    """Return a short, deterministic identifier for the edge (s, p, o).

    The three components are rendered to text, joined with ``|`` and hashed
    with SHA-256; the first 8 hex characters of the digest form the ID.
    """
    # str.format applies the same formatting protocol as an f-string, so the
    # hashed text is identical for any argument type.
    joined = "{}|{}|{}".format(s, p, o)
    digest = hashlib.sha256(joined.encode())
    return digest.hexdigest()[:8]
|
||||
|
||||
class LRUCacheWithTTL:
|
||||
"""LRU cache with TTL for label caching
|
||||
|
||||
|
|
@ -258,7 +280,14 @@ class Query:
|
|||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def get_labelgraph(self, query):
|
||||
"""
|
||||
Get subgraph with labels resolved for display.
|
||||
|
||||
Returns:
|
||||
tuple: (labeled_edges, uri_map) where:
|
||||
- labeled_edges: list of (label_s, label_p, label_o) tuples
|
||||
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
|
||||
"""
|
||||
subgraph = await self.get_subgraph(query)
|
||||
|
||||
# Filter out label triples
|
||||
|
|
@ -281,27 +310,33 @@ class Query:
|
|||
else:
|
||||
label_map[entity] = entity # Fallback to entity itself
|
||||
|
||||
# Apply labels to subgraph
|
||||
sg2 = []
|
||||
# Apply labels to subgraph and build URI mapping
|
||||
labeled_edges = []
|
||||
uri_map = {} # Maps edge_id of labeled edge -> original URI triple
|
||||
|
||||
for s, p, o in filtered_subgraph:
|
||||
labeled_triple = (
|
||||
label_map.get(s, s),
|
||||
label_map.get(p, p),
|
||||
label_map.get(o, o)
|
||||
)
|
||||
sg2.append(labeled_triple)
|
||||
labeled_edges.append(labeled_triple)
|
||||
|
||||
sg2 = sg2[0:self.max_subgraph_size]
|
||||
# Map from labeled edge ID to original URIs
|
||||
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
|
||||
uri_map[labeled_eid] = (s, p, o)
|
||||
|
||||
labeled_edges = labeled_edges[0:self.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Subgraph:")
|
||||
for edge in sg2:
|
||||
for edge in labeled_edges:
|
||||
logger.debug(f" {str(edge)}")
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
|
||||
return sg2
|
||||
return labeled_edges, uri_map
|
||||
|
||||
class GraphRag:
|
||||
"""
|
||||
|
|
@ -335,11 +370,44 @@ class GraphRag:
|
|||
self, query, user = "trustgraph", collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, streaming = False, chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
):
|
||||
"""
|
||||
Execute a GraphRAG query with real-time explainability tracking.
|
||||
|
||||
Args:
|
||||
query: The query string
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
entity_limit: Max entities to retrieve
|
||||
triple_limit: Max triples per entity
|
||||
max_subgraph_size: Max edges in subgraph
|
||||
max_path_length: Max hops from seed entities
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for real-time explainability
|
||||
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
|
||||
|
||||
Returns:
|
||||
str: The synthesized answer text
|
||||
"""
|
||||
if self.verbose:
|
||||
logger.debug("Constructing prompt...")
|
||||
|
||||
# Generate explainability URIs upfront
|
||||
session_id = str(uuid.uuid4())
|
||||
session_uri = query_session_uri(session_id)
|
||||
ret_uri = make_retrieval_uri(session_id)
|
||||
sel_uri = make_selection_uri(session_id)
|
||||
ans_uri = make_answer_uri(session_id)
|
||||
|
||||
timestamp = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
# Emit session explainability immediately
|
||||
if explain_callback:
|
||||
session_triples = query_session_triples(session_uri, query, timestamp)
|
||||
await explain_callback(session_triples, session_uri)
|
||||
|
||||
q = Query(
|
||||
rag = self, user = user, collection = collection,
|
||||
verbose = self.verbose, entity_limit = entity_limit,
|
||||
|
|
@ -348,24 +416,171 @@ class GraphRag:
|
|||
max_path_length = max_path_length,
|
||||
)
|
||||
|
||||
kg = await q.get_labelgraph(query)
|
||||
kg, uri_map = await q.get_labelgraph(query)
|
||||
|
||||
# Emit retrieval explain after graph retrieval completes
|
||||
if explain_callback:
|
||||
ret_triples = retrieval_triples(ret_uri, session_uri, len(kg))
|
||||
await explain_callback(ret_triples, ret_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
logger.debug(f"Knowledge graph: {kg}")
|
||||
logger.debug(f"Query: {query}")
|
||||
|
||||
if streaming and chunk_callback:
|
||||
resp = await self.prompt_client.kg_prompt(
|
||||
query, kg,
|
||||
streaming=True,
|
||||
chunk_callback=chunk_callback
|
||||
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
|
||||
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
|
||||
edge_map = {}
|
||||
edges_with_ids = []
|
||||
for s, p, o in kg:
|
||||
eid = edge_id(s, p, o)
|
||||
edge_map[eid] = (s, p, o)
|
||||
edges_with_ids.append({
|
||||
"id": eid,
|
||||
"s": s,
|
||||
"p": p,
|
||||
"o": o
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Built edge map with {len(edge_map)} edges")
|
||||
|
||||
# Step 1: Edge Selection - LLM selects relevant edges with reasoning
|
||||
selection_response = await self.prompt_client.prompt(
|
||||
"kg-edge-selection",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": edges_with_ids
|
||||
}
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge selection response: {selection_response}")
|
||||
|
||||
# Parse response to get selected edge IDs and reasoning
|
||||
# Response can be a string (JSONL) or a list (JSON array)
|
||||
selected_ids = set()
|
||||
selected_edges_with_reasoning = [] # For explain
|
||||
|
||||
if isinstance(selection_response, list):
|
||||
# JSON array response
|
||||
for obj in selection_response:
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
selected_ids.add(obj["id"])
|
||||
# Capture original URI edge (not labels) and reasoning for explain
|
||||
eid = obj["id"]
|
||||
if eid in uri_map:
|
||||
# Use original URIs for provenance tracing
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": obj.get("reasoning", ""),
|
||||
})
|
||||
elif isinstance(selection_response, str):
|
||||
# JSONL string response
|
||||
for line in selection_response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
if "id" in obj:
|
||||
selected_ids.add(obj["id"])
|
||||
# Capture original URI edge (not labels) and reasoning for explain
|
||||
eid = obj["id"]
|
||||
if eid in uri_map:
|
||||
# Use original URIs for provenance tracing
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": obj.get("reasoning", ""),
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse edge selection line: {line}")
|
||||
continue
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Selected {len(selected_ids)} edges: {selected_ids}")
|
||||
|
||||
# Filter to selected edges
|
||||
selected_edges = []
|
||||
for eid in selected_ids:
|
||||
if eid in edge_map:
|
||||
selected_edges.append(edge_map[eid])
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Filtered to {len(selected_edges)} edges")
|
||||
|
||||
# Emit selection explain after edge selection completes
|
||||
if explain_callback:
|
||||
sel_triples = selection_triples(
|
||||
sel_uri, ret_uri, selected_edges_with_reasoning, session_id
|
||||
)
|
||||
await explain_callback(sel_triples, sel_uri)
|
||||
|
||||
# Step 2: Synthesis - LLM generates answer from selected edges only
|
||||
selected_edge_dicts = [
|
||||
{"s": s, "p": p, "o": o}
|
||||
for s, p, o in selected_edges
|
||||
]
|
||||
if streaming and chunk_callback:
|
||||
# Accumulate chunks for answer storage while forwarding to callback
|
||||
accumulated_chunks = []
|
||||
|
||||
async def accumulating_callback(chunk, end_of_stream):
|
||||
accumulated_chunks.append(chunk)
|
||||
await chunk_callback(chunk, end_of_stream)
|
||||
|
||||
await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edge_dicts
|
||||
},
|
||||
streaming=True,
|
||||
chunk_callback=accumulating_callback
|
||||
)
|
||||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
resp = await self.prompt_client.kg_prompt(query, kg)
|
||||
resp = await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edge_dicts
|
||||
}
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
||||
# Emit answer explain after synthesis completes
|
||||
if explain_callback:
|
||||
answer_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian if callback provided
|
||||
if save_answer_callback and answer_text:
|
||||
# Generate document ID as URN matching query-time provenance format
|
||||
answer_doc_id = f"urn:trustgraph:answer:{session_id}"
|
||||
try:
|
||||
await save_answer_callback(answer_doc_id, answer_text)
|
||||
if self.verbose:
|
||||
logger.debug(f"Saved answer to librarian: {answer_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
answer_doc_id = None # Fall back to inline content
|
||||
|
||||
# Generate triples with document reference or inline content
|
||||
ans_triples = answer_triples(
|
||||
ans_uri, sel_uri,
|
||||
answer_text="" if answer_doc_id else answer_text,
|
||||
document_id=answer_doc_id,
|
||||
)
|
||||
await explain_callback(ans_triples, ans_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Emitted explain for session {session_id}")
|
||||
|
||||
return resp
|
||||
|
||||
|
|
|
|||
|
|
@ -4,18 +4,28 @@ Simple RAG service, performs query using graph RAG an LLM.
|
|||
Input is query, output is response.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from ... schema import GraphRagQuery, GraphRagResponse, Error
|
||||
from ... schema import Triples, Metadata
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from . graph_rag import GraphRag
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
|
||||
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "graph-rag"
|
||||
default_concurrency = 1
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
|
|
@ -28,6 +38,7 @@ class Processor(FlowProcessor):
|
|||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
explainability_collection = params.get("explainability_collection", "explainability")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
|
|
@ -37,6 +48,7 @@ class Processor(FlowProcessor):
|
|||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"explainability_collection": explainability_collection,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -44,6 +56,7 @@ class Processor(FlowProcessor):
|
|||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.explainability_collection = explainability_collection
|
||||
|
||||
# CRITICAL SECURITY: NEVER share data between users or collections
|
||||
# Each user/collection combination MUST have isolated data access
|
||||
|
|
@ -93,10 +106,163 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "explainability",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
# Librarian client for storing answer content
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend=self.pubsub,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
backend=self.pubsub,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self.on_librarian_response,
|
||||
metrics=librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_librarian_requests = {}
|
||||
|
||||
logger.info("Graph RAG service initialized")
|
||||
|
||||
async def start(self):
    """Start the base processor, then the librarian producer and consumer.

    The components are started one after another, in that order.
    """
    await super(Processor, self).start()

    librarian_endpoints = (
        self.librarian_request_producer,
        self.librarian_response_consumer,
    )
    for endpoint in librarian_endpoints:
        await endpoint.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
    """Handle responses from the librarian service.

    Looks up the future registered in ``self.pending_librarian_requests``
    under the message's ``"id"`` property and resolves it with the message
    value. Responses with no matching (or no longer waiting) request are
    logged and dropped.

    Args:
        msg: Incoming message; must provide ``value()`` and ``properties()``.
        consumer: Unused here.
        flow: Unused here.
    """
    response = msg.value()
    request_id = msg.properties().get("id")

    # Remove the entry in a single step so it can never be resolved twice.
    future = None
    if request_id:
        future = self.pending_librarian_requests.pop(request_id, None)

    # A waiter that timed out or was cancelled leaves a finished future
    # behind; calling set_result() on it would raise InvalidStateError
    # inside the consumer handler, so treat it like an unexpected response.
    if future is not None and not future.done():
        future.set_result(response)
    else:
        logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
    """
    Save answer content to the librarian.

    Sends an "add-document" request carrying the content as base64-encoded
    UTF-8 text, then waits for the response that on_librarian_response()
    correlates back via the request's "id" property.

    Args:
        doc_id: ID for the answer document
        user: User ID
        content: Answer text content
        title: Optional title (defaults to "GraphRAG Answer")
        timeout: Request timeout in seconds

    Returns:
        The document ID on success

    Raises:
        RuntimeError: If the librarian reports an error, or no response
            arrives within ``timeout`` seconds.
    """
    request_id = str(uuid.uuid4())

    doc_metadata = DocumentMetadata(
        id=doc_id,
        user=user,
        kind="text/plain",
        title=title or "GraphRAG Answer",
        document_type="answer",
    )

    request = LibrarianRequest(
        operation="add-document",
        document_id=doc_id,
        document_metadata=doc_metadata,
        content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
        user=user,
    )

    # Register the future before sending so a fast response cannot arrive
    # ahead of the registration.
    future = asyncio.get_running_loop().create_future()
    self.pending_librarian_requests[request_id] = future

    try:
        # Send request
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )

        # Wait for response
        response = await asyncio.wait_for(future, timeout=timeout)

    except asyncio.TimeoutError as e:
        raise RuntimeError(f"Timeout saving answer document {doc_id}") from e

    finally:
        # Always drop the bookkeeping entry: on timeout, on a failed send,
        # and on cancellation alike. On success the response handler has
        # already removed it, so this is a no-op.
        self.pending_librarian_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error saving answer: {response.error.type}: {response.error.message}"
        )

    return doc_id
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
logger.info(f"Handling input {id}...")
|
||||
|
||||
# Track explainability refs for end_of_session signaling
|
||||
explainability_refs_emitted = []
|
||||
|
||||
# Real-time explainability callback - emits triples and IDs as they're generated
|
||||
async def send_explainability(triples, explain_id):
|
||||
# Send triples to explainability queue
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=explain_id,
|
||||
metadata=[],
|
||||
user=v.user,
|
||||
collection=self.explainability_collection,
|
||||
),
|
||||
triples=triples,
|
||||
))
|
||||
|
||||
# Send explain ID and collection to response queue
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id=explain_id,
|
||||
explain_collection=self.explainability_collection,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
explainability_refs_emitted.append(explain_id)
|
||||
|
||||
# CRITICAL SECURITY: Create new GraphRag instance per request
|
||||
# This ensures proper isolation between users and collections
|
||||
# Flow clients are request-scoped and must not be shared
|
||||
|
|
@ -108,13 +274,6 @@ class Processor(FlowProcessor):
|
|||
verbose=True,
|
||||
)
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
logger.info(f"Handling input {id}...")
|
||||
|
||||
if v.entity_limit:
|
||||
entity_limit = v.entity_limit
|
||||
else:
|
||||
|
|
@ -135,6 +294,15 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
# Callback to save answer content to librarian
|
||||
async def save_answer(doc_id, answer_text):
|
||||
await self.save_answer_content(
|
||||
doc_id=doc_id,
|
||||
user=v.user,
|
||||
content=answer_text,
|
||||
title=f"GraphRAG Answer: {v.query[:50]}...",
|
||||
)
|
||||
|
||||
# Check if streaming is requested
|
||||
if v.streaming:
|
||||
# Define async callback for streaming chunks
|
||||
|
|
@ -142,6 +310,7 @@ class Processor(FlowProcessor):
|
|||
async def send_chunk(chunk, end_of_stream):
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response=chunk,
|
||||
end_of_stream=end_of_stream,
|
||||
error=None
|
||||
|
|
@ -149,34 +318,50 @@ class Processor(FlowProcessor):
|
|||
properties={"id": id}
|
||||
)
|
||||
|
||||
# Query with streaming enabled
|
||||
# All chunks (including final one with end_of_stream=True) are sent via callback
|
||||
await rag.query(
|
||||
# Query with streaming and real-time explain
|
||||
response = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
)
|
||||
|
||||
else:
|
||||
# Non-streaming path (existing behavior)
|
||||
# Non-streaming path with real-time explain
|
||||
response = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
)
|
||||
|
||||
# Send chunk with response
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
response = response,
|
||||
end_of_stream = True,
|
||||
error = None
|
||||
message_type="chunk",
|
||||
response=response,
|
||||
end_of_stream=True,
|
||||
error=None,
|
||||
),
|
||||
properties = {"id": id}
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
# Send final message to close session
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="",
|
||||
end_of_session=True,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
logger.info("Request processing complete")
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -185,22 +370,18 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.debug("Sending error response...")
|
||||
|
||||
# Send error response with end_of_stream flag if streaming was requested
|
||||
error_response = GraphRagResponse(
|
||||
response = None,
|
||||
error = Error(
|
||||
type = "graph-rag-error",
|
||||
message = str(e),
|
||||
),
|
||||
)
|
||||
|
||||
# If streaming was requested, indicate stream end
|
||||
if v.streaming:
|
||||
error_response.end_of_stream = True
|
||||
|
||||
# Send error response and close session
|
||||
await flow("response").send(
|
||||
error_response,
|
||||
properties = {"id": id}
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
error=Error(
|
||||
type="graph-rag-error",
|
||||
message=str(e),
|
||||
),
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -243,6 +424,12 @@ class Processor(FlowProcessor):
|
|||
help=f'Default max path length (default: 2)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--explainability-collection',
|
||||
default='explainability',
|
||||
help=f'Collection for storing explainability triples (default: explainability)'
|
||||
)
|
||||
|
||||
def run():
    """Entry point: launch Processor with the default identifier and this module's docstring."""
    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -443,13 +443,22 @@ class McpServer:
|
|||
|
||||
gen = manager.request("graph-rag", request_data, flow_id)
|
||||
|
||||
text_chunks = []
|
||||
async for response in gen:
|
||||
# Handle new message format with message_type
|
||||
message_type = response.get("message_type", "chunk")
|
||||
|
||||
# Extract vectors from response
|
||||
text = response.get("response", "")
|
||||
break
|
||||
|
||||
return GraphRagResponse(response=text)
|
||||
# Only collect text from chunk messages
|
||||
if message_type == "chunk":
|
||||
chunk_text = response.get("response", "")
|
||||
if chunk_text:
|
||||
text_chunks.append(chunk_text)
|
||||
|
||||
# Check if session is complete
|
||||
if response.get("end_of_session"):
|
||||
break
|
||||
|
||||
return GraphRagResponse(response="".join(text_chunks))
|
||||
|
||||
async def agent(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue