Fix non streaming RAG problems (#607)

* Fix non-streaming failure in RAG services

* Fix non-streaming failure in API

* Fix agent non-streaming messaging

* Agent messaging unit & contract tests
This commit is contained in:
cybermaggedon 2026-01-12 18:45:52 +00:00 committed by GitHub
parent 30ca1d2e8b
commit 807f6cc4e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 677 additions and 21 deletions

View file

@ -0,0 +1,206 @@
"""
Unit tests for Agent service non-streaming mode.
Tests that end_of_message and end_of_dialog flags are correctly set.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.agent.react.service import Processor
from trustgraph.schema import AgentRequest, AgentResponse
from trustgraph.agent.react.types import Final
class TestAgentServiceNonStreaming:
"""Test Agent service non-streaming behavior"""
@patch('trustgraph.agent.react.service.AgentManager')
@pytest.mark.asyncio
async def test_non_streaming_intermediate_messages_have_correct_flags(self, mock_agent_manager_class):
"""
Test that intermediate messages (thought/observation) in non-streaming mode
have end_of_message=True and end_of_dialog=False.
"""
# Setup processor
processor = Processor(
taskgroup=MagicMock(),
id="test-agent",
max_iterations=10
)
# Track all responses sent
sent_responses = []
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
# Mock react to call think and observe callbacks
async def mock_react(question, history, think, observe, answer, context, streaming):
await think("I need to solve this.", is_final=True)
await observe("The answer is 4.", is_final=True)
return Final(thought="Final answer", final="4")
mock_agent_instance.react = mock_react
# Setup message with non-streaming request
msg = MagicMock()
msg.value.return_value = AgentRequest(
question="What is 2 + 2?",
user="trustgraph",
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
mock_producer = AsyncMock()
async def capture_response(response, properties):
sent_responses.append(response)
mock_producer.send = AsyncMock(side_effect=capture_response)
def flow_router(service_name):
if service_name == "response":
return mock_producer
return AsyncMock()
flow.side_effect = flow_router
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: should have 3 responses (thought, observation, answer)
assert len(sent_responses) == 3, f"Expected 3 responses, got {len(sent_responses)}"
# Check thought message
thought_response = sent_responses[0]
assert isinstance(thought_response, AgentResponse)
assert thought_response.thought == "I need to solve this."
assert thought_response.answer is None
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
# Check observation message
observation_response = sent_responses[1]
assert isinstance(observation_response, AgentResponse)
assert observation_response.observation == "The answer is 4."
assert observation_response.answer is None
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
@patch('trustgraph.agent.react.service.AgentManager')
@pytest.mark.asyncio
async def test_non_streaming_final_answer_has_correct_flags(self, mock_agent_manager_class):
"""
Test that final answer in non-streaming mode has
end_of_message=True and end_of_dialog=True.
"""
# Setup processor
processor = Processor(
taskgroup=MagicMock(),
id="test-agent",
max_iterations=10
)
# Track all responses sent
sent_responses = []
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
# Mock react to return Final directly
async def mock_react(question, history, think, observe, answer, context, streaming):
return Final(thought="Final answer", final="4")
mock_agent_instance.react = mock_react
# Setup message with non-streaming request
msg = MagicMock()
msg.value.return_value = AgentRequest(
question="What is 2 + 2?",
user="trustgraph",
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
mock_producer = AsyncMock()
async def capture_response(response, properties):
sent_responses.append(response)
mock_producer.send = AsyncMock(side_effect=capture_response)
def flow_router(service_name):
if service_name == "response":
return mock_producer
return AsyncMock()
flow.side_effect = flow_router
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: should have 1 response (final answer)
assert len(sent_responses) == 1, f"Expected 1 response, got {len(sent_responses)}"
# Check final answer message
answer_response = sent_responses[0]
assert isinstance(answer_response, AgentResponse)
assert answer_response.answer == "4"
assert answer_response.thought is None
assert answer_response.observation is None
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"
@pytest.mark.asyncio
async def test_error_response_has_correct_flags(self):
"""
Test that error responses have end_of_message=True and end_of_dialog=True.
"""
# Setup processor that will error
processor = Processor(
taskgroup=MagicMock(),
id="test-agent",
max_iterations=10
)
# Track all responses sent
sent_responses = []
# Setup message
msg = MagicMock()
msg.value.side_effect = Exception("Test error")
msg.properties.return_value = {"id": "test-id"}
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.producer = {"response": AsyncMock()}
async def capture_response(response, properties):
sent_responses.append(response)
flow.producer["response"].send = AsyncMock(side_effect=capture_response)
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: should have 1 error response
assert len(sent_responses) == 1, f"Expected 1 error response, got {len(sent_responses)}"
# Check error response
error_response = sent_responses[0]
assert isinstance(error_response, AgentResponse)
assert error_response.error is not None
assert "Test error" in error_response.error.message
assert error_response.end_of_message is True, "Error response must have end_of_message=True"
assert error_response.end_of_dialog is True, "Error response must have end_of_dialog=True"

View file

@ -74,4 +74,58 @@ class TestDocumentRagService:
sent_response = mock_producer.send.call_args[0][0]
assert isinstance(sent_response, DocumentRagResponse)
assert sent_response.response == "test response"
assert sent_response.error is None
@patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
@pytest.mark.asyncio
async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_document_rag_class):
"""
Test that non-streaming mode sets end_of_stream=True in response.
This is a regression test for the bug where non-streaming responses
didn't set end_of_stream, causing clients to hang waiting for more data.
"""
# Setup processor
processor = Processor(
taskgroup=MagicMock(),
id="test-processor",
doc_limit=10
)
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "A document about cats."
# Setup message with non-streaming request
msg = MagicMock()
msg.value.return_value = DocumentRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
doc_limit=10,
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
mock_producer = AsyncMock()
def flow_router(service_name):
if service_name == "response":
return mock_producer
return AsyncMock()
flow.side_effect = flow_router
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: response was sent with end_of_stream=True
mock_producer.send.assert_called_once()
sent_response = mock_producer.send.call_args[0][0]
assert isinstance(sent_response, DocumentRagResponse)
assert sent_response.response == "A document about cats."
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
assert sent_response.error is None

View file

@ -0,0 +1,134 @@
"""
Unit tests for GraphRAG service non-streaming mode.
Tests that end_of_stream flag is correctly set in non-streaming responses.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.retrieval.graph_rag.rag import Processor
from trustgraph.schema import GraphRagQuery, GraphRagResponse
class TestGraphRagService:
"""Test GraphRAG service non-streaming behavior"""
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
@pytest.mark.asyncio
async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_graph_rag_class):
"""
Test that non-streaming mode sets end_of_stream=True in response.
This is a regression test for the bug where non-streaming responses
didn't set end_of_stream, causing clients to hang waiting for more data.
"""
# Setup processor
processor = Processor(
taskgroup=MagicMock(),
id="test-processor",
entity_limit=50,
triple_limit=30,
max_subgraph_size=150,
max_path_length=2
)
# Setup mock GraphRag instance
mock_rag_instance = AsyncMock()
mock_graph_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "A small domesticated mammal."
# Setup message with non-streaming request
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
entity_limit=50,
triple_limit=30,
max_subgraph_size=150,
max_path_length=2,
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
# Mock flow to return AsyncMock for clients and response producer
mock_producer = AsyncMock()
def flow_router(service_name):
if service_name == "response":
return mock_producer
return AsyncMock() # embeddings, graph-embeddings, triples, prompt clients
flow.side_effect = flow_router
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: response was sent with end_of_stream=True
mock_producer.send.assert_called_once()
sent_response = mock_producer.send.call_args[0][0]
assert isinstance(sent_response, GraphRagResponse)
assert sent_response.response == "A small domesticated mammal."
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
assert sent_response.error is None
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
@pytest.mark.asyncio
async def test_error_response_in_non_streaming_mode(self, mock_graph_rag_class):
"""
Test that error responses in non-streaming mode set end_of_stream=True.
"""
# Setup processor
processor = Processor(
taskgroup=MagicMock(),
id="test-processor",
entity_limit=50,
triple_limit=30,
max_subgraph_size=150,
max_path_length=2
)
# Setup mock GraphRag instance that raises an exception
mock_rag_instance = AsyncMock()
mock_graph_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.side_effect = Exception("Test error")
# Setup message with non-streaming request
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
entity_limit=50,
triple_limit=30,
max_subgraph_size=150,
max_path_length=2,
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
mock_producer = AsyncMock()
def flow_router(service_name):
if service_name == "response":
return mock_producer
return AsyncMock()
flow.side_effect = flow_router
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: error response was sent without end_of_stream (not streaming mode)
mock_producer.send.assert_called_once()
sent_response = mock_producer.send.call_args[0][0]
assert isinstance(sent_response, GraphRagResponse)
assert sent_response.response is None
assert sent_response.error is not None
assert sent_response.error.message == "Test error"
# Note: error responses in non-streaming mode don't set end_of_stream
# because streaming was never started