mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
Fix non streaming RAG problems (#607)
* Fix non-streaming failure in RAG services
* Fix non-streaming failure in API
* Fix agent non-streaming messaging
* Agent messaging unit & contract tests
This commit is contained in:
parent
30ca1d2e8b
commit
807f6cc4e2
10 changed files with 677 additions and 21 deletions
|
|
@ -74,4 +74,58 @@ class TestDocumentRagService:
|
|||
sent_response = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_response, DocumentRagResponse)
|
||||
assert sent_response.response == "test response"
|
||||
assert sent_response.error is None
|
||||
|
||||
@patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
@pytest.mark.asyncio
async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_document_rag_class):
    """
    Regression check: a non-streaming request yields one final response.

    Non-streaming responses used to omit end_of_stream, which left
    clients waiting for more data.  With streaming=False exactly one
    response must be sent and it must carry end_of_stream=True.
    """
    # Service under test
    processor = Processor(
        taskgroup=MagicMock(),
        id="test-processor",
        doc_limit=10,
    )

    # The patched DocumentRag class hands back a stub whose query()
    # resolves to a fixed answer
    rag_stub = AsyncMock()
    rag_stub.query.return_value = "A document about cats."
    mock_document_rag_class.return_value = rag_stub

    # Incoming message wrapping a request with streaming switched off
    request = DocumentRagQuery(
        query="What is a cat?",
        user="trustgraph",
        collection="default",
        doc_limit=10,
        streaming=False,
    )
    msg = MagicMock()
    msg.properties.return_value = {"id": "test-id"}
    msg.value.return_value = request

    # flow("response") must resolve to the producer we inspect afterwards;
    # every other service name gets a throwaway async mock
    mock_producer = AsyncMock()
    flow = MagicMock()
    flow.side_effect = lambda service_name: (
        mock_producer if service_name == "response" else AsyncMock()
    )
    consumer = MagicMock()

    # Run the handler
    await processor.on_request(msg, consumer, flow)

    # Exactly one response, flagged as the end of the stream
    mock_producer.send.assert_called_once()
    positional_args = mock_producer.send.call_args[0]
    sent_response = positional_args[0]

    assert isinstance(sent_response, DocumentRagResponse)
    assert sent_response.response == "A document about cats."
    assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
    assert sent_response.error is None
|
||||
134
tests/unit/test_retrieval/test_graph_rag_service.py
Normal file
134
tests/unit/test_retrieval/test_graph_rag_service.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
Unit tests for GraphRAG service non-streaming mode.
|
||||
Tests that end_of_stream flag is correctly set in non-streaming responses.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.retrieval.graph_rag.rag import Processor
|
||||
from trustgraph.schema import GraphRagQuery, GraphRagResponse
|
||||
|
||||
|
||||
class TestGraphRagService:
    """Test GraphRAG service non-streaming behavior"""

    @staticmethod
    def _make_processor():
        """Build the Processor under test with fixed GraphRAG limits."""
        return Processor(
            taskgroup=MagicMock(),
            id="test-processor",
            entity_limit=50,
            triple_limit=30,
            max_subgraph_size=150,
            max_path_length=2,
        )

    @staticmethod
    def _make_non_streaming_message():
        """Build a mock message carrying a GraphRagQuery with streaming=False."""
        msg = MagicMock()
        msg.value.return_value = GraphRagQuery(
            query="What is a cat?",
            user="trustgraph",
            collection="default",
            entity_limit=50,
            triple_limit=30,
            max_subgraph_size=150,
            max_path_length=2,
            streaming=False,  # Non-streaming mode
        )
        msg.properties.return_value = {"id": "test-id"}
        return msg

    @staticmethod
    def _make_flow():
        """
        Build a mock flow and the response producer it routes to.

        Returns a ``(flow, mock_producer)`` pair: ``flow("response")``
        returns ``mock_producer``; any other service name returns a fresh
        AsyncMock (embeddings, graph-embeddings, triples, prompt clients).
        """
        mock_producer = AsyncMock()

        def flow_router(service_name):
            if service_name == "response":
                return mock_producer
            return AsyncMock()

        flow = MagicMock()
        flow.side_effect = flow_router
        return flow, mock_producer

    @patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
    @pytest.mark.asyncio
    async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_graph_rag_class):
        """
        Test that non-streaming mode sets end_of_stream=True in response.

        This is a regression test for the bug where non-streaming responses
        didn't set end_of_stream, causing clients to hang waiting for more data.
        """
        # Setup processor
        processor = self._make_processor()

        # Setup mock GraphRag instance
        mock_rag_instance = AsyncMock()
        mock_graph_rag_class.return_value = mock_rag_instance
        mock_rag_instance.query.return_value = "A small domesticated mammal."

        # Setup message and flow mocks
        msg = self._make_non_streaming_message()
        consumer = MagicMock()
        flow, mock_producer = self._make_flow()

        # Execute
        await processor.on_request(msg, consumer, flow)

        # Verify: response was sent with end_of_stream=True
        mock_producer.send.assert_called_once()
        sent_response = mock_producer.send.call_args[0][0]
        assert isinstance(sent_response, GraphRagResponse)
        assert sent_response.response == "A small domesticated mammal."
        assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
        assert sent_response.error is None

    @patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
    @pytest.mark.asyncio
    async def test_error_response_in_non_streaming_mode(self, mock_graph_rag_class):
        """
        Test that a failing query in non-streaming mode yields an error response.

        Exactly one response must be sent, with no response text and with
        the exception message in ``error.message``.  end_of_stream is
        deliberately not asserted here (see the note at the end of the test).
        """
        # Setup processor
        processor = self._make_processor()

        # Setup mock GraphRag instance that raises an exception
        mock_rag_instance = AsyncMock()
        mock_graph_rag_class.return_value = mock_rag_instance
        mock_rag_instance.query.side_effect = Exception("Test error")

        # Setup message and flow mocks
        msg = self._make_non_streaming_message()
        consumer = MagicMock()
        flow, mock_producer = self._make_flow()

        # Execute
        await processor.on_request(msg, consumer, flow)

        # Verify: a single error response was sent
        mock_producer.send.assert_called_once()
        sent_response = mock_producer.send.call_args[0][0]
        assert isinstance(sent_response, GraphRagResponse)
        assert sent_response.response is None
        assert sent_response.error is not None
        assert sent_response.error.message == "Test error"
        # Note: error responses in non-streaming mode don't set end_of_stream
        # because streaming was never started
|
||||
Loading…
Add table
Add a link
Reference in a new issue