mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
GraphRAG Query-Time Explainability (#677)
Implements full explainability pipeline for GraphRAG queries, enabling
traceability from answers back to source documents.
Renamed throughout for clarity:
- provenance_callback → explain_callback
- provenance_id → explain_id
- provenance_collection → explain_collection
- message_type "provenance" → "explain"
- Queue name "provenance" → "explainability"
GraphRAG queries now emit explainability events as they execute:
1. Session - query text and timestamp
2. Retrieval - edges retrieved from subgraph
3. Selection - selected edges with LLM reasoning (JSONL with id +
reasoning)
4. Answer - reference to synthesized response
Events stream via explain_callback during query(), enabling
real-time UX.
- Answers stored in librarian service (not inline in graph - too large)
- Document ID as URN: urn:trustgraph:answer:{session_id}
- Graph stores tg:document reference (IRI) to librarian document
- Added librarian producer/consumer to graph-rag service
- get_labelgraph() now returns (labeled_edges, uri_map)
- uri_map maps edge_id(label_s, label_p, label_o) →
(uri_s, uri_p, uri_o)
- Explainability data stores original URIs, not labels
- Enables tracing edges back to reifying statements via tg:reifies
- Added serialize_triple() to query service (matches storage format)
- get_term_value() now handles TRIPLE type terms
- Enables querying by quoted triple in object position:
?stmt tg:reifies <<s p o>>
- Displays real-time explainability events during query
- Resolves rdfs:label for edge components (s, p, o)
- Traces source chain via prov:wasDerivedFrom to root document
- Output: "Source: Chunk 1 → Page 2 → Document Title"
- Label caching to avoid repeated queries
GraphRagResponse:
- explain_id: str | None
- explain_collection: str | None
- message_type: str ("chunk" or "explain")
- end_of_session: bool
trustgraph-base/trustgraph/provenance/:
- namespaces.py - Added TG_DOCUMENT predicate
- triples.py - answer_triples() supports document_id reference
- uris.py - Added edge_selection_uri()
trustgraph-base/trustgraph/schema/services/retrieval.py:
- GraphRagResponse with explain_id, explain_collection, end_of_session
trustgraph-flow/trustgraph/retrieval/graph_rag/:
- graph_rag.py - URI preservation, streaming answer accumulation
- rag.py - Librarian integration, real-time explain emission
trustgraph-flow/trustgraph/query/triples/cassandra/service.py:
- Quoted triple serialization for query matching
trustgraph-cli/trustgraph/cli/invoke_graph_rag.py:
- Full explainability display with label resolution and source tracing
This commit is contained in:
parent
d2d71f859d
commit
7a6197d8c3
24 changed files with 2001 additions and 323 deletions
|
|
@ -547,21 +547,21 @@ class TestServiceHelperFunctions:
|
|||
"""Test cases for helper functions in service.py"""
|
||||
|
||||
def test_create_term_with_uri_otype(self):
|
||||
"""Test create_term creates IRI Term for otype='u'"""
|
||||
"""Test create_term creates IRI Term for term_type='u'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import IRI
|
||||
|
||||
term = create_term('http://example.org/Alice', otype='u')
|
||||
term = create_term('http://example.org/Alice', term_type='u')
|
||||
|
||||
assert term.type == IRI
|
||||
assert term.iri == 'http://example.org/Alice'
|
||||
|
||||
def test_create_term_with_literal_otype(self):
|
||||
"""Test create_term creates LITERAL Term for otype='l'"""
|
||||
"""Test create_term creates LITERAL Term for term_type='l'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import LITERAL
|
||||
|
||||
term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en')
|
||||
term = create_term('Alice Smith', term_type='l', datatype='xsd:string', language='en')
|
||||
|
||||
assert term.type == LITERAL
|
||||
assert term.value == 'Alice Smith'
|
||||
|
|
@ -569,7 +569,7 @@ class TestServiceHelperFunctions:
|
|||
assert term.language == 'en'
|
||||
|
||||
def test_create_term_with_triple_otype(self):
|
||||
"""Test create_term creates TRIPLE Term for otype='t' with valid JSON"""
|
||||
"""Test create_term creates TRIPLE Term for term_type='t' with valid JSON"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import TRIPLE, IRI
|
||||
import json
|
||||
|
|
@ -581,7 +581,7 @@ class TestServiceHelperFunctions:
|
|||
"o": {"type": "i", "iri": "http://example.org/Bob"},
|
||||
})
|
||||
|
||||
term = create_term(triple_json, otype='t')
|
||||
term = create_term(triple_json, term_type='t')
|
||||
|
||||
assert term.type == TRIPLE
|
||||
assert term.triple is not None
|
||||
|
|
|
|||
|
|
@ -96,20 +96,21 @@ class TestGraphRagResponseTranslator:
|
|||
assert is_final is False
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
# Test final chunk with empty content
|
||||
# Test final message with end_of_session=True
|
||||
final_response = GraphRagResponse(
|
||||
response="",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(final_response)
|
||||
|
||||
# Assert
|
||||
# Assert - is_final is based on end_of_session, not end_of_stream
|
||||
assert is_final is True
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
assert result["end_of_session"] is True
|
||||
|
||||
|
||||
class TestDocumentRagResponseTranslator:
|
||||
|
|
|
|||
|
|
@ -118,8 +118,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
# Verify result contains the queried triple
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
def test_processor_initialization_with_defaults(self):
|
||||
|
|
@ -182,8 +182,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -219,8 +219,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -256,8 +256,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -293,8 +293,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -331,8 +331,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'all_subject'
|
||||
assert result[0].p.value == 'all_predicate'
|
||||
assert result[0].s.iri == 'all_subject'
|
||||
assert result[0].p.iri == 'all_predicate'
|
||||
assert result[0].o.value == 'all_object'
|
||||
|
||||
def test_add_args_method(self):
|
||||
|
|
@ -637,8 +637,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -678,8 +678,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -802,7 +802,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
# Verify all results were returned
|
||||
assert len(result) == 5
|
||||
for i, triple in enumerate(result):
|
||||
assert triple.s.value == f'subject_{i}' # Mock returns literal values
|
||||
assert triple.s.iri == f'subject_{i}' # Mock returns literal values
|
||||
assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
|
||||
assert triple.p.type == IRI
|
||||
assert triple.o.iri == 'http://example.com/Person' # URIs use .iri
|
||||
|
|
|
|||
|
|
@ -540,41 +540,68 @@ class TestQuery:
|
|||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
# Call get_labelgraph
|
||||
result = await query.get_labelgraph("test query")
|
||||
|
||||
labeled_edges, uri_map = await query.get_labelgraph("test query")
|
||||
|
||||
# Verify get_subgraph was called
|
||||
query.get_subgraph.assert_called_once_with("test query")
|
||||
|
||||
|
||||
# Verify label triples are filtered out
|
||||
assert len(result) == 2 # Label triple should be excluded
|
||||
|
||||
assert len(labeled_edges) == 2 # Label triple should be excluded
|
||||
|
||||
# Verify maybe_label was called for non-label triples
|
||||
expected_calls = [
|
||||
(("entity1",), {}), (("predicate1",), {}), (("object1",), {}),
|
||||
(("entity3",), {}), (("predicate3",), {}), (("object3",), {})
|
||||
]
|
||||
assert query.maybe_label.call_count == 6
|
||||
|
||||
|
||||
# Verify result contains human-readable labels
|
||||
expected_result = [
|
||||
expected_edges = [
|
||||
("Human Entity One", "Human Predicate One", "Human Object One"),
|
||||
("Human Entity Three", "Human Predicate Three", "Human Object Three")
|
||||
]
|
||||
assert result == expected_result
|
||||
assert labeled_edges == expected_edges
|
||||
|
||||
# Verify uri_map maps labeled edges back to original URIs
|
||||
assert len(uri_map) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_query_method(self):
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline"""
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline with real-time provenance"""
|
||||
import json
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
|
||||
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
|
||||
# Mock prompt client response
|
||||
|
||||
# Mock prompt client responses for two-step process
|
||||
expected_response = "This is the RAG response"
|
||||
mock_prompt_client.kg_prompt.return_value = expected_response
|
||||
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
|
||||
# Compute the edge ID for the test edge
|
||||
test_edge_id = edge_id("Subject", "Predicate", "Object")
|
||||
|
||||
# Create uri_map for the test edge (maps labeled edge ID to original URIs)
|
||||
test_uri_map = {
|
||||
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
}
|
||||
|
||||
# Mock edge selection response (JSONL format)
|
||||
edge_selection_response = json.dumps({"id": test_edge_id, "reasoning": "relevant"})
|
||||
|
||||
# Configure prompt mock to return different responses based on prompt name
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return edge_selection_response
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return expected_response
|
||||
return ""
|
||||
|
||||
mock_prompt_client.prompt = mock_prompt
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -583,39 +610,55 @@ class TestQuery:
|
|||
triples_client=mock_triples_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Mock the Query class behavior by patching get_labelgraph
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
|
||||
|
||||
# We need to patch the Query class's get_labelgraph method
|
||||
original_query_init = Query.__init__
|
||||
original_get_labelgraph = Query.get_labelgraph
|
||||
|
||||
|
||||
def mock_query_init(self, *args, **kwargs):
|
||||
original_query_init(self, *args, **kwargs)
|
||||
|
||||
|
||||
async def mock_get_labelgraph(self, query_text):
|
||||
return test_labelgraph
|
||||
|
||||
return test_labelgraph, test_uri_map
|
||||
|
||||
Query.__init__ = mock_query_init
|
||||
Query.get_labelgraph = mock_get_labelgraph
|
||||
|
||||
|
||||
# Collect provenance emitted via callback
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
try:
|
||||
# Call GraphRag.query
|
||||
result = await graph_rag.query(
|
||||
# Call GraphRag.query with provenance callback
|
||||
response = await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=25,
|
||||
triple_limit=15
|
||||
triple_limit=15,
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Verify prompt client was called with knowledge graph and query
|
||||
mock_prompt_client.kg_prompt.assert_called_once_with("test query", test_labelgraph)
|
||||
|
||||
# Verify result
|
||||
assert result == expected_response
|
||||
|
||||
|
||||
# Verify response text
|
||||
assert response == expected_response
|
||||
|
||||
# Verify provenance was emitted incrementally (4 events: session, retrieval, selection, answer)
|
||||
assert len(provenance_events) == 4
|
||||
|
||||
# Verify each event has triples and a URN
|
||||
for triples, prov_id in provenance_events:
|
||||
assert isinstance(triples, list)
|
||||
assert len(triples) > 0
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
# Verify order: session, retrieval, selection, answer
|
||||
assert "session" in provenance_events[0][1]
|
||||
assert "retrieval" in provenance_events[1][1]
|
||||
assert "selection" in provenance_events[2][1]
|
||||
assert "answer" in provenance_events[3][1]
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
Query.__init__ = original_query_init
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Unit tests for GraphRAG service non-streaming mode.
|
||||
Tests that end_of_stream flag is correctly set in non-streaming responses.
|
||||
Unit tests for GraphRAG service message format.
|
||||
Tests the new message protocol with message_type, explain_id, and end_of_session.
|
||||
Real-time explainability emission via callback.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -11,16 +12,14 @@ from trustgraph.schema import GraphRagQuery, GraphRagResponse
|
|||
|
||||
|
||||
class TestGraphRagService:
|
||||
"""Test GraphRAG service non-streaming behavior"""
|
||||
"""Test GraphRAG service message protocol"""
|
||||
|
||||
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_graph_rag_class):
|
||||
async def test_non_streaming_sends_chunk_then_provenance_messages(self, mock_graph_rag_class):
|
||||
"""
|
||||
Test that non-streaming mode sets end_of_stream=True in response.
|
||||
|
||||
This is a regression test for the bug where non-streaming responses
|
||||
didn't set end_of_stream, causing clients to hang waiting for more data.
|
||||
Test that non-streaming mode sends real-time provenance messages
|
||||
followed by chunk message with response.
|
||||
"""
|
||||
# Setup processor
|
||||
processor = Processor(
|
||||
|
|
@ -32,10 +31,22 @@ class TestGraphRagService:
|
|||
max_path_length=2
|
||||
)
|
||||
|
||||
# Setup mock GraphRag instance
|
||||
# Setup mock GraphRag instance that calls explain_callback
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_graph_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = "A small domesticated mammal."
|
||||
|
||||
# Mock query() to call the explain_callback with each provenance event
|
||||
async def mock_query(**kwargs):
|
||||
explain_callback = kwargs.get('explain_callback')
|
||||
if explain_callback:
|
||||
# Simulate real-time provenance emission
|
||||
await explain_callback([], "urn:trustgraph:session:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:selection:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:answer:test")
|
||||
return "A small domesticated mammal."
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
# Setup message with non-streaming request
|
||||
msg = MagicMock()
|
||||
|
|
@ -47,7 +58,7 @@ class TestGraphRagService:
|
|||
triple_limit=30,
|
||||
max_subgraph_size=150,
|
||||
max_path_length=2,
|
||||
streaming=False # Non-streaming mode
|
||||
streaming=False
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
||||
|
|
@ -55,30 +66,48 @@ class TestGraphRagService:
|
|||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
|
||||
# Mock flow to return AsyncMock for clients and response producer
|
||||
mock_producer = AsyncMock()
|
||||
mock_response_producer = AsyncMock()
|
||||
mock_provenance_producer = AsyncMock()
|
||||
def flow_router(service_name):
|
||||
if service_name == "response":
|
||||
return mock_producer
|
||||
return AsyncMock() # embeddings, graph-embeddings, triples, prompt clients
|
||||
return mock_response_producer
|
||||
elif service_name == "explainability":
|
||||
return mock_provenance_producer
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: response was sent with end_of_stream=True
|
||||
mock_producer.send.assert_called_once()
|
||||
sent_response = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_response, GraphRagResponse)
|
||||
assert sent_response.response == "A small domesticated mammal."
|
||||
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
|
||||
assert sent_response.error is None
|
||||
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
|
||||
assert mock_response_producer.send.call_count == 6
|
||||
|
||||
# First 4 messages are explain (emitted in real-time during query)
|
||||
for i in range(4):
|
||||
prov_msg = mock_response_producer.send.call_args_list[i][0][0]
|
||||
assert prov_msg.message_type == "explain"
|
||||
assert prov_msg.explain_id is not None
|
||||
|
||||
# 5th message is chunk with response
|
||||
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "A small domesticated mammal."
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# 6th message is empty chunk with end_of_session=True
|
||||
close_msg = mock_response_producer.send.call_args_list[5][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
|
||||
# Verify provenance triples were sent to provenance queue
|
||||
assert mock_provenance_producer.send.call_count == 4
|
||||
|
||||
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_response_in_non_streaming_mode(self, mock_graph_rag_class):
|
||||
async def test_error_response_closes_session(self, mock_graph_rag_class):
|
||||
"""
|
||||
Test that error responses in non-streaming mode set end_of_stream=True.
|
||||
Test that error responses set end_of_session=True.
|
||||
"""
|
||||
# Setup processor
|
||||
processor = Processor(
|
||||
|
|
@ -105,7 +134,7 @@ class TestGraphRagService:
|
|||
triple_limit=30,
|
||||
max_subgraph_size=150,
|
||||
max_path_length=2,
|
||||
streaming=False # Non-streaming mode
|
||||
streaming=False
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
||||
|
|
@ -113,22 +142,93 @@ class TestGraphRagService:
|
|||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
mock_response_producer = AsyncMock()
|
||||
mock_provenance_producer = AsyncMock()
|
||||
def flow_router(service_name):
|
||||
if service_name == "response":
|
||||
return mock_producer
|
||||
return mock_response_producer
|
||||
elif service_name == "explainability":
|
||||
return mock_provenance_producer
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: error response was sent without end_of_stream (not streaming mode)
|
||||
mock_producer.send.assert_called_once()
|
||||
sent_response = mock_producer.send.call_args[0][0]
|
||||
# Verify: error response was sent with session closed
|
||||
mock_response_producer.send.assert_called_once()
|
||||
sent_response = mock_response_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_response, GraphRagResponse)
|
||||
assert sent_response.response is None
|
||||
assert sent_response.message_type == "chunk"
|
||||
assert sent_response.error is not None
|
||||
assert sent_response.error.message == "Test error"
|
||||
# Note: error responses in non-streaming mode don't set end_of_stream
|
||||
# because streaming was never started
|
||||
assert sent_response.end_of_stream is True
|
||||
assert sent_response.end_of_session is True
|
||||
|
||||
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_provenance_sends_empty_chunk_to_close(self, mock_graph_rag_class):
|
||||
"""
|
||||
Test that when no provenance callback is invoked, an empty chunk closes the session.
|
||||
"""
|
||||
# Setup processor
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id="test-processor",
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
max_subgraph_size=150,
|
||||
max_path_length=2
|
||||
)
|
||||
|
||||
# Setup mock GraphRag instance that doesn't call provenance callback
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_graph_rag_class.return_value = mock_rag_instance
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
# Don't call explain_callback
|
||||
return "Response text"
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
# Setup message
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="Test query",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
streaming=False
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
||||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
|
||||
mock_response_producer = AsyncMock()
|
||||
mock_provenance_producer = AsyncMock()
|
||||
def flow_router(service_name):
|
||||
if service_name == "response":
|
||||
return mock_response_producer
|
||||
elif service_name == "explainability":
|
||||
return mock_provenance_producer
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: 2 messages (chunk + empty chunk to close)
|
||||
assert mock_response_producer.send.call_count == 2
|
||||
|
||||
# First is the response chunk
|
||||
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "Response text"
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# Second is empty chunk to close session
|
||||
close_msg = mock_response_producer.send.call_args_list[1][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue