Update to add streaming tests (#600)

This commit is contained in:
cybermaggedon 2026-01-06 21:48:05 +00:00 committed by GitHub
parent f0c95a4c5e
commit f79d0603f7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 1062 additions and 57 deletions

View file

@ -0,0 +1,260 @@
"""
Unit tests for PromptClient streaming callback behavior.
These tests verify that the prompt client correctly passes the end_of_stream
flag to chunk callbacks, ensuring proper streaming protocol compliance.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, call, patch
from trustgraph.base.prompt_client import PromptClient
from trustgraph.schema import PromptResponse
class TestPromptClientStreamingCallback:
    """Test PromptClient streaming callback behavior.

    Every test replaces ``PromptClient.request`` with a scripted coroutine
    that replays a fixed list of ``PromptResponse`` messages, then checks
    which ``(text, end_of_stream)`` pairs reach ``chunk_callback``.
    """

    @staticmethod
    def _scripted_request(chunks):
        """Build a stand-in for ``PromptClient.request`` that replays ``chunks``.

        Args:
            chunks: Sequence of ``(text, end_of_stream)`` pairs, in the order
                they should be delivered.

        Returns:
            An async callable accepting ``(request, recipient=None,
            timeout=600)``. When ``recipient`` is supplied, each pair is
            wrapped in a ``PromptResponse`` and awaited through it; replay
            stops as soon as the recipient returns a truthy value. Without a
            recipient the callable does nothing and returns ``None``.
        """
        async def replay(request, recipient=None, timeout=600):
            if not recipient:
                return
            for text, end_of_stream in chunks:
                message = PromptResponse(
                    text=text,
                    object=None,
                    error=None,
                    end_of_stream=end_of_stream,
                )
                # A truthy return from the recipient means "stop streaming".
                if await recipient(message):
                    break

        return replay

    @pytest.fixture
    def prompt_client(self):
        """Create a PromptClient whose real ``__init__`` is bypassed."""
        # Replace __init__ with a no-op so construction needs no arguments.
        with patch.object(PromptClient, '__init__', lambda self: None):
            return PromptClient()

    @pytest.fixture
    def mock_request_response(self):
        """Create a mock request/response handler.

        With a recipient it streams "Hello", " world", "!" and finally an
        empty end-of-stream chunk; without one it returns a single complete
        ``PromptResponse``.
        """
        stream = self._scripted_request([
            ("Hello", False),
            (" world", False),
            ("!", False),
            ("", True),
        ])

        async def mock_request(request, recipient=None, timeout=600):
            if recipient:
                # Simulate streaming responses
                return await stream(
                    request, recipient=recipient, timeout=timeout
                )
            # Non-streaming response
            return PromptResponse(text="Hello world!", object=None, error=None)

        return mock_request

    @pytest.mark.asyncio
    async def test_callback_receives_chunk_and_end_of_stream(
        self, prompt_client, mock_request_response
    ):
        """Test that callback receives both chunk text and end_of_stream flag"""
        # Arrange
        prompt_client.request = mock_request_response
        callback = AsyncMock()

        # Act
        await prompt_client.prompt(
            id="test-prompt",
            variables={"query": "test"},
            streaming=True,
            chunk_callback=callback,
        )

        # Assert - one (chunk, end_of_stream) call per streamed response,
        # in order, ending with the empty end-of-stream chunk
        assert callback.call_args_list == [
            call("Hello", False),
            call(" world", False),
            call("!", False),
            call("", True),
        ]

    @pytest.mark.asyncio
    async def test_callback_receives_empty_final_chunk(
        self, prompt_client, mock_request_response
    ):
        """Test that empty final chunks are passed to callback"""
        # Arrange
        prompt_client.request = mock_request_response
        chunks_received = []

        async def collect_chunks(chunk, end_of_stream):
            chunks_received.append((chunk, end_of_stream))

        # Act
        await prompt_client.prompt(
            id="test-prompt",
            variables={"query": "test"},
            streaming=True,
            chunk_callback=collect_chunks,
        )

        # Assert - should receive the empty final chunk
        final_chunk = chunks_received[-1]
        assert final_chunk == ("", True), "Final chunk should be empty string with end_of_stream=True"

    @pytest.mark.asyncio
    async def test_callback_signature_with_non_empty_final_chunk(self, prompt_client):
        """Test callback signature when LLM sends non-empty final chunk"""
        # Arrange - some LLMs send content in the final chunk
        prompt_client.request = self._scripted_request([
            ("Hello", False),
            (" world!", True),
        ])
        callback = AsyncMock()

        # Act
        await prompt_client.prompt(
            id="test-prompt",
            variables={"query": "test"},
            streaming=True,
            chunk_callback=callback,
        )

        # Assert
        assert callback.call_args_list == [
            call("Hello", False),
            call(" world!", True),
        ]

    @pytest.mark.asyncio
    async def test_callback_not_called_without_text(self, prompt_client):
        """Test that callback is not called for responses without text"""
        # Arrange - second response carries only end_of_stream, text is None
        prompt_client.request = self._scripted_request([
            ("Content", False),
            (None, True),
        ])
        callback = AsyncMock()

        # Act
        await prompt_client.prompt(
            id="test-prompt",
            variables={"query": "test"},
            streaming=True,
            chunk_callback=callback,
        )

        # Assert - callback should only be called once (for "Content")
        assert callback.call_args_list == [call("Content", False)]

    @pytest.mark.asyncio
    async def test_synchronous_callback_also_receives_end_of_stream(self, prompt_client):
        """Test that synchronous callbacks also receive end_of_stream parameter"""
        # Arrange
        prompt_client.request = self._scripted_request([
            ("test", False),
            ("", True),
        ])
        callback = MagicMock()  # Synchronous mock

        # Act
        await prompt_client.prompt(
            id="test-prompt",
            variables={"query": "test"},
            streaming=True,
            chunk_callback=callback,
        )

        # Assert - synchronous callback should also get both parameters
        assert callback.call_args_list == [
            call("test", False),
            call("", True),
        ]

    @pytest.mark.asyncio
    async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
        """Test that kg_prompt correctly passes streaming parameters"""
        # Arrange
        prompt_client.request = self._scripted_request([
            ("Answer", False),
            ("", True),
        ])
        callback = AsyncMock()

        # Act
        await prompt_client.kg_prompt(
            query="What is machine learning?",
            kg=[("subject", "predicate", "object")],
            streaming=True,
            chunk_callback=callback,
        )

        # Assert
        assert callback.call_args_list == [
            call("Answer", False),
            call("", True),
        ]

    @pytest.mark.asyncio
    async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
        """Test that document_prompt correctly passes streaming parameters"""
        # Arrange
        prompt_client.request = self._scripted_request([
            ("Summary", False),
            ("", True),
        ])
        callback = AsyncMock()

        # Act
        await prompt_client.document_prompt(
            query="Summarize this",
            documents=["doc1", "doc2"],
            streaming=True,
            chunk_callback=callback,
        )

        # Assert
        assert callback.call_args_list == [
            call("Summary", False),
            call("", True),
        ]

View file

@ -102,7 +102,7 @@ async def test_handle_normal_flow():
"""Test normal websocket handling flow."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
dispatcher_created = False
async def mock_dispatcher_factory(ws, running, match_info):
nonlocal dispatcher_created
@ -110,33 +110,41 @@ async def test_handle_normal_flow():
dispatcher = AsyncMock()
dispatcher.destroy = AsyncMock()
return dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
request = MagicMock()
request.query = {"token": "valid-token"}
request.match_info = {}
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
mock_ws = AsyncMock()
mock_ws.prepare = AsyncMock()
mock_ws.close = AsyncMock()
mock_ws.closed = False
mock_ws_class.return_value = mock_ws
with patch('asyncio.TaskGroup') as mock_task_group:
# Mock task group context manager
mock_tg = AsyncMock()
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
mock_tg.__aexit__ = AsyncMock(return_value=None)
mock_tg.create_task = MagicMock(return_value=AsyncMock())
# Create proper mock tasks that look like asyncio.Task objects
def create_task_mock(coro):
task = AsyncMock()
task.done = MagicMock(return_value=True)
task.cancelled = MagicMock(return_value=False)
return task
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
mock_task_group.return_value = mock_tg
result = await socket_endpoint.handle(request)
# Should have created dispatcher
assert dispatcher_created is True
# Should return websocket
assert result == mock_ws
@ -146,50 +154,58 @@ async def test_handle_exception_group_cleanup():
"""Test exception group triggers dispatcher cleanup."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_dispatcher = AsyncMock()
mock_dispatcher.destroy = AsyncMock()
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
request = MagicMock()
request.query = {"token": "valid-token"}
request.match_info = {}
# Mock TaskGroup to raise ExceptionGroup
class TestException(Exception):
pass
exception_group = ExceptionGroup("Test exceptions", [TestException("test")])
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
mock_ws = AsyncMock()
mock_ws.prepare = AsyncMock()
mock_ws.close = AsyncMock()
mock_ws.close = AsyncMock()
mock_ws.closed = False
mock_ws_class.return_value = mock_ws
with patch('asyncio.TaskGroup') as mock_task_group:
mock_tg = AsyncMock()
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
mock_tg.create_task = MagicMock(side_effect=TestException("test"))
# Create proper mock tasks that look like asyncio.Task objects
def create_task_mock(coro):
task = AsyncMock()
task.done = MagicMock(return_value=True)
task.cancelled = MagicMock(return_value=False)
return task
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
mock_task_group.return_value = mock_tg
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
mock_wait_for.return_value = None
result = await socket_endpoint.handle(request)
# Should have attempted graceful cleanup
mock_wait_for.assert_called_once()
# Should have called destroy in finally block
assert mock_dispatcher.destroy.call_count >= 1
# Should have closed websocket
mock_ws.close.assert_called()
@ -199,48 +215,56 @@ async def test_handle_dispatcher_cleanup_timeout():
"""Test dispatcher cleanup with timeout."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
# Mock dispatcher that takes long to destroy
mock_dispatcher = AsyncMock()
mock_dispatcher.destroy = AsyncMock()
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
request = MagicMock()
request.query = {"token": "valid-token"}
request.match_info = {}
# Mock TaskGroup to raise exception
exception_group = ExceptionGroup("Test", [Exception("test")])
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
mock_ws = AsyncMock()
mock_ws.prepare = AsyncMock()
mock_ws.close = AsyncMock()
mock_ws.closed = False
mock_ws_class.return_value = mock_ws
with patch('asyncio.TaskGroup') as mock_task_group:
mock_tg = AsyncMock()
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
mock_tg.create_task = MagicMock(side_effect=Exception("test"))
# Create proper mock tasks that look like asyncio.Task objects
def create_task_mock(coro):
task = AsyncMock()
task.done = MagicMock(return_value=True)
task.cancelled = MagicMock(return_value=False)
return task
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
mock_task_group.return_value = mock_tg
# Mock asyncio.wait_for to raise TimeoutError
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")
result = await socket_endpoint.handle(request)
# Should have attempted cleanup with timeout
mock_wait_for.assert_called_once()
# Check that timeout was passed correctly
assert mock_wait_for.call_args[1]['timeout'] == 5.0
# Should still call destroy in finally block
assert mock_dispatcher.destroy.call_count >= 1
@ -290,37 +314,45 @@ async def test_handle_websocket_already_closed():
"""Test handling when websocket is already closed."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_dispatcher = AsyncMock()
mock_dispatcher.destroy = AsyncMock()
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
request = MagicMock()
request.query = {"token": "valid-token"}
request.match_info = {}
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
mock_ws = AsyncMock()
mock_ws.prepare = AsyncMock()
mock_ws.close = AsyncMock()
mock_ws.closed = True # Already closed
mock_ws_class.return_value = mock_ws
with patch('asyncio.TaskGroup') as mock_task_group:
mock_tg = AsyncMock()
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
mock_tg.__aexit__ = AsyncMock(return_value=None)
mock_tg.create_task = MagicMock(return_value=AsyncMock())
# Create proper mock tasks that look like asyncio.Task objects
def create_task_mock(coro):
task = AsyncMock()
task.done = MagicMock(return_value=True)
task.cancelled = MagicMock(return_value=False)
return task
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
mock_task_group.return_value = mock_tg
result = await socket_endpoint.handle(request)
# Should still have called destroy
mock_dispatcher.destroy.assert_called()
# Should not attempt to close already closed websocket
mock_ws.close.assert_not_called() # Not called in finally since ws.closed = True

View file

@ -0,0 +1,326 @@
"""
Unit tests for streaming behavior in message translators.
These tests verify that translators correctly handle empty strings and
end_of_stream flags in streaming responses, preventing bugs where empty
final chunks could be dropped due to falsy value checks.
"""
import pytest
from unittest.mock import MagicMock
from trustgraph.messaging.translators.retrieval import (
GraphRagResponseTranslator,
DocumentRagResponseTranslator,
)
from trustgraph.messaging.translators.prompt import PromptResponseTranslator
from trustgraph.messaging.translators.text_completion import TextCompletionResponseTranslator
from trustgraph.schema import (
GraphRagResponse,
DocumentRagResponse,
PromptResponse,
TextCompletionResponse,
)
class TestGraphRagResponseTranslator:
    """Streaming behaviour of GraphRagResponseTranslator."""

    def test_from_pulsar_with_empty_response(self):
        """An empty response string must survive translation."""
        # Arrange
        message = GraphRagResponse(response="", end_of_stream=True, error=None)

        # Act
        payload = GraphRagResponseTranslator().from_pulsar(message)

        # Assert - the key is present and still holds the empty string
        assert "response" in payload
        assert payload["response"] == ""
        assert payload["end_of_stream"] is True

    def test_from_pulsar_with_non_empty_response(self):
        """A populated response is passed through unchanged."""
        # Arrange
        message = GraphRagResponse(
            response="Some text", end_of_stream=False, error=None
        )

        # Act
        payload = GraphRagResponseTranslator().from_pulsar(message)

        # Assert
        assert payload["response"] == "Some text"
        assert payload["end_of_stream"] is False

    def test_from_pulsar_with_none_response(self):
        """A ``None`` response yields no ``response`` key at all."""
        # Arrange
        message = GraphRagResponse(response=None, end_of_stream=True, error=None)

        # Act
        payload = GraphRagResponseTranslator().from_pulsar(message)

        # Assert - the key is omitted, the flag is still reported
        assert "response" not in payload
        assert payload["end_of_stream"] is True

    def test_from_response_with_completion_returns_correct_flag(self):
        """``from_response_with_completion`` reports ``is_final`` per chunk."""
        # Arrange - one translator instance handles both chunks
        xlate = GraphRagResponseTranslator()
        mid_stream = GraphRagResponse(
            response="chunk", end_of_stream=False, error=None
        )
        last = GraphRagResponse(response="", end_of_stream=True, error=None)

        # Act / Assert - non-final chunk
        payload, is_final = xlate.from_response_with_completion(mid_stream)
        assert is_final is False
        assert payload["end_of_stream"] is False

        # Act / Assert - final chunk with empty content
        payload, is_final = xlate.from_response_with_completion(last)
        assert is_final is True
        assert payload["response"] == ""
        assert payload["end_of_stream"] is True
class TestDocumentRagResponseTranslator:
    """Streaming behaviour of DocumentRagResponseTranslator."""

    def test_from_pulsar_with_empty_response(self):
        """An empty response string must survive translation."""
        # Arrange
        message = DocumentRagResponse(
            response="", end_of_stream=True, error=None
        )

        # Act
        payload = DocumentRagResponseTranslator().from_pulsar(message)

        # Assert
        assert "response" in payload
        assert payload["response"] == ""
        assert payload["end_of_stream"] is True

    def test_from_pulsar_with_non_empty_response(self):
        """A populated response is passed through unchanged."""
        # Arrange
        message = DocumentRagResponse(
            response="Document content", end_of_stream=False, error=None
        )

        # Act
        payload = DocumentRagResponseTranslator().from_pulsar(message)

        # Assert
        assert payload["response"] == "Document content"
        assert payload["end_of_stream"] is False
class TestPromptResponseTranslator:
    """Streaming behaviour of PromptResponseTranslator."""

    def test_from_pulsar_with_empty_text(self):
        """An empty text string must survive translation."""
        # Arrange
        message = PromptResponse(
            text="", object=None, end_of_stream=True, error=None
        )

        # Act
        payload = PromptResponseTranslator().from_pulsar(message)

        # Assert
        assert "text" in payload
        assert payload["text"] == ""
        assert payload["end_of_stream"] is True

    def test_from_pulsar_with_non_empty_text(self):
        """Populated text is passed through unchanged."""
        # Arrange
        message = PromptResponse(
            text="Some prompt response",
            object=None,
            end_of_stream=False,
            error=None,
        )

        # Act
        payload = PromptResponseTranslator().from_pulsar(message)

        # Assert
        assert payload["text"] == "Some prompt response"
        assert payload["end_of_stream"] is False

    def test_from_pulsar_with_none_text(self):
        """``None`` text is omitted while the object payload is kept."""
        # Arrange
        message = PromptResponse(
            text=None,
            object='{"result": "data"}',
            end_of_stream=True,
            error=None,
        )

        # Act
        payload = PromptResponseTranslator().from_pulsar(message)

        # Assert
        assert "text" not in payload
        assert "object" in payload
        assert payload["end_of_stream"] is True

    def test_from_pulsar_includes_end_of_stream(self):
        """The ``end_of_stream`` key is emitted even when its value is False."""
        # Arrange - a mid-stream chunk (end_of_stream=False)
        message = PromptResponse(
            text="chunk", object=None, end_of_stream=False, error=None
        )

        # Act
        payload = PromptResponseTranslator().from_pulsar(message)

        # Assert
        assert "end_of_stream" in payload
        assert payload["end_of_stream"] is False
class TestTextCompletionResponseTranslator:
    """Streaming behaviour of TextCompletionResponseTranslator."""

    def test_from_pulsar_always_includes_response(self):
        """The ``response`` key is emitted even for an empty string."""
        # Arrange
        message = TextCompletionResponse(
            response="",
            end_of_stream=True,
            error=None,
            in_token=100,
            out_token=5,
            model="test-model",
        )

        # Act
        payload = TextCompletionResponseTranslator().from_pulsar(message)

        # Assert - the key exists and still holds the empty string
        assert "response" in payload
        assert payload["response"] == ""

    def test_from_response_with_completion_with_empty_final(self):
        """An empty final message is flagged as final and keeps its text."""
        # Arrange
        message = TextCompletionResponse(
            response="",
            end_of_stream=True,
            error=None,
            in_token=100,
            out_token=5,
            model="test-model",
        )

        # Act
        payload, is_final = (
            TextCompletionResponseTranslator()
            .from_response_with_completion(message)
        )

        # Assert
        assert is_final is True
        assert payload["response"] == ""
class TestStreamingProtocolCompliance:
    """Test that all translators follow streaming protocol conventions"""

    # One row per translator under test:
    # (translator class, schema class it translates, name of its text field).
    # Shared by both tests so neither can silently drift out of coverage.
    _TRANSLATOR_CASES = [
        (GraphRagResponseTranslator, GraphRagResponse, "response"),
        (DocumentRagResponseTranslator, DocumentRagResponse, "response"),
        (PromptResponseTranslator, PromptResponse, "text"),
        (TextCompletionResponseTranslator, TextCompletionResponse, "response"),
    ]

    @pytest.mark.parametrize(
        "translator_class,response_class,field_name", _TRANSLATOR_CASES
    )
    def test_empty_final_chunk_preserved(self, translator_class, response_class, field_name):
        """Test that all translators preserve empty final chunks"""
        # Arrange
        translator = translator_class()
        kwargs = {
            field_name: "",
            "end_of_stream": True,
            "error": None,
        }
        response = response_class(**kwargs)

        # Act
        result = translator.from_pulsar(response)

        # Assert
        assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
        assert result[field_name] == "", f"{translator_class.__name__} should preserve empty string"

    @pytest.mark.parametrize(
        "translator_class,response_class,field_name", _TRANSLATOR_CASES
    )
    def test_end_of_stream_flag_included(self, translator_class, response_class, field_name):
        """Test that end_of_stream flag is included in all response translators"""
        # Arrange
        translator = translator_class()
        kwargs = {
            field_name: "test content",
            "end_of_stream": True,
            "error": None,
        }
        response = response_class(**kwargs)

        # Act
        result = translator.from_pulsar(response)

        # Assert
        assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"
        assert result["end_of_stream"] is True