mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-30 02:46:23 +02:00
Update to add streaming tests (#600)
This commit is contained in:
parent
f0c95a4c5e
commit
f79d0603f7
9 changed files with 1062 additions and 57 deletions
|
|
@ -102,7 +102,7 @@ async def test_handle_normal_flow():
|
|||
"""Test normal websocket handling flow."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
|
||||
dispatcher_created = False
|
||||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
nonlocal dispatcher_created
|
||||
|
|
@ -110,33 +110,41 @@ async def test_handle_normal_flow():
|
|||
dispatcher = AsyncMock()
|
||||
dispatcher.destroy = AsyncMock()
|
||||
return dispatcher
|
||||
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
request.match_info = {}
|
||||
|
||||
|
||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.prepare = AsyncMock()
|
||||
mock_ws.close = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
mock_ws_class.return_value = mock_ws
|
||||
|
||||
|
||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||
# Mock task group context manager
|
||||
mock_tg = AsyncMock()
|
||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_tg.create_task = MagicMock(return_value=AsyncMock())
|
||||
|
||||
# Create proper mock tasks that look like asyncio.Task objects
|
||||
def create_task_mock(coro):
|
||||
task = AsyncMock()
|
||||
task.done = MagicMock(return_value=True)
|
||||
task.cancelled = MagicMock(return_value=False)
|
||||
return task
|
||||
|
||||
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
|
||||
mock_task_group.return_value = mock_tg
|
||||
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
|
||||
# Should have created dispatcher
|
||||
assert dispatcher_created is True
|
||||
|
||||
|
||||
# Should return websocket
|
||||
assert result == mock_ws
|
||||
|
||||
|
|
@ -146,50 +154,58 @@ async def test_handle_exception_group_cleanup():
|
|||
"""Test exception group triggers dispatcher cleanup."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.destroy = AsyncMock()
|
||||
|
||||
|
||||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
request.match_info = {}
|
||||
|
||||
|
||||
# Mock TaskGroup to raise ExceptionGroup
|
||||
class TestException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
exception_group = ExceptionGroup("Test exceptions", [TestException("test")])
|
||||
|
||||
|
||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.prepare = AsyncMock()
|
||||
mock_ws.close = AsyncMock()
|
||||
mock_ws.close = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
mock_ws_class.return_value = mock_ws
|
||||
|
||||
|
||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||
mock_tg = AsyncMock()
|
||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
|
||||
mock_tg.create_task = MagicMock(side_effect=TestException("test"))
|
||||
|
||||
# Create proper mock tasks that look like asyncio.Task objects
|
||||
def create_task_mock(coro):
|
||||
task = AsyncMock()
|
||||
task.done = MagicMock(return_value=True)
|
||||
task.cancelled = MagicMock(return_value=False)
|
||||
return task
|
||||
|
||||
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
|
||||
mock_task_group.return_value = mock_tg
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
|
||||
mock_wait_for.return_value = None
|
||||
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
|
||||
# Should have attempted graceful cleanup
|
||||
mock_wait_for.assert_called_once()
|
||||
|
||||
|
||||
# Should have called destroy in finally block
|
||||
assert mock_dispatcher.destroy.call_count >= 1
|
||||
|
||||
|
||||
# Should have closed websocket
|
||||
mock_ws.close.assert_called()
|
||||
|
||||
|
|
@ -199,48 +215,56 @@ async def test_handle_dispatcher_cleanup_timeout():
|
|||
"""Test dispatcher cleanup with timeout."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
|
||||
# Mock dispatcher that takes long to destroy
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.destroy = AsyncMock()
|
||||
|
||||
|
||||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
request.match_info = {}
|
||||
|
||||
|
||||
# Mock TaskGroup to raise exception
|
||||
exception_group = ExceptionGroup("Test", [Exception("test")])
|
||||
|
||||
|
||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.prepare = AsyncMock()
|
||||
mock_ws.close = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
mock_ws_class.return_value = mock_ws
|
||||
|
||||
|
||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||
mock_tg = AsyncMock()
|
||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
|
||||
mock_tg.create_task = MagicMock(side_effect=Exception("test"))
|
||||
|
||||
# Create proper mock tasks that look like asyncio.Task objects
|
||||
def create_task_mock(coro):
|
||||
task = AsyncMock()
|
||||
task.done = MagicMock(return_value=True)
|
||||
task.cancelled = MagicMock(return_value=False)
|
||||
return task
|
||||
|
||||
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
|
||||
mock_task_group.return_value = mock_tg
|
||||
|
||||
|
||||
# Mock asyncio.wait_for to raise TimeoutError
|
||||
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
|
||||
mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")
|
||||
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
|
||||
# Should have attempted cleanup with timeout
|
||||
mock_wait_for.assert_called_once()
|
||||
# Check that timeout was passed correctly
|
||||
assert mock_wait_for.call_args[1]['timeout'] == 5.0
|
||||
|
||||
|
||||
# Should still call destroy in finally block
|
||||
assert mock_dispatcher.destroy.call_count >= 1
|
||||
|
||||
|
|
@ -290,37 +314,45 @@ async def test_handle_websocket_already_closed():
|
|||
"""Test handling when websocket is already closed."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.destroy = AsyncMock()
|
||||
|
||||
|
||||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
request.match_info = {}
|
||||
|
||||
|
||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.prepare = AsyncMock()
|
||||
mock_ws.close = AsyncMock()
|
||||
mock_ws.closed = True # Already closed
|
||||
mock_ws_class.return_value = mock_ws
|
||||
|
||||
|
||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||
mock_tg = AsyncMock()
|
||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_tg.create_task = MagicMock(return_value=AsyncMock())
|
||||
|
||||
# Create proper mock tasks that look like asyncio.Task objects
|
||||
def create_task_mock(coro):
|
||||
task = AsyncMock()
|
||||
task.done = MagicMock(return_value=True)
|
||||
task.cancelled = MagicMock(return_value=False)
|
||||
return task
|
||||
|
||||
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
|
||||
mock_task_group.return_value = mock_tg
|
||||
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
|
||||
# Should still have called destroy
|
||||
mock_dispatcher.destroy.assert_called()
|
||||
|
||||
|
||||
# Should not attempt to close already closed websocket
|
||||
mock_ws.close.assert_not_called() # Not called in finally since ws.closed = True
|
||||
326
tests/unit/test_gateway/test_streaming_translators.py
Normal file
326
tests/unit/test_gateway/test_streaming_translators.py
Normal file
|
|
@ -0,0 +1,326 @@
|
|||
"""
|
||||
Unit tests for streaming behavior in message translators.
|
||||
|
||||
These tests verify that translators correctly handle empty strings and
|
||||
end_of_stream flags in streaming responses, preventing bugs where empty
|
||||
final chunks could be dropped due to falsy value checks.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from trustgraph.messaging.translators.retrieval import (
|
||||
GraphRagResponseTranslator,
|
||||
DocumentRagResponseTranslator,
|
||||
)
|
||||
from trustgraph.messaging.translators.prompt import PromptResponseTranslator
|
||||
from trustgraph.messaging.translators.text_completion import TextCompletionResponseTranslator
|
||||
from trustgraph.schema import (
|
||||
GraphRagResponse,
|
||||
DocumentRagResponse,
|
||||
PromptResponse,
|
||||
TextCompletionResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestGraphRagResponseTranslator:
|
||||
"""Test GraphRagResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_response(self):
|
||||
"""Test that empty response strings are preserved"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
response = GraphRagResponse(
|
||||
response="",
|
||||
end_of_stream=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert - Empty string should be included in result
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_response(self):
|
||||
"""Test that non-empty responses work correctly"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
response = GraphRagResponse(
|
||||
response="Some text",
|
||||
end_of_stream=False,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert result["response"] == "Some text"
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
def test_from_pulsar_with_none_response(self):
|
||||
"""Test that None response is handled correctly"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
response = GraphRagResponse(
|
||||
response=None,
|
||||
end_of_stream=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert - None should not be included
|
||||
assert "response" not in result
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_response_with_completion_returns_correct_flag(self):
|
||||
"""Test that from_response_with_completion returns correct is_final flag"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
# Test non-final chunk
|
||||
response_chunk = GraphRagResponse(
|
||||
response="chunk",
|
||||
end_of_stream=False,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(response_chunk)
|
||||
|
||||
# Assert
|
||||
assert is_final is False
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
# Test final chunk with empty content
|
||||
final_response = GraphRagResponse(
|
||||
response="",
|
||||
end_of_stream=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(final_response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
|
||||
class TestDocumentRagResponseTranslator:
|
||||
"""Test DocumentRagResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_response(self):
|
||||
"""Test that empty response strings are preserved"""
|
||||
# Arrange
|
||||
translator = DocumentRagResponseTranslator()
|
||||
response = DocumentRagResponse(
|
||||
response="",
|
||||
end_of_stream=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_response(self):
|
||||
"""Test that non-empty responses work correctly"""
|
||||
# Arrange
|
||||
translator = DocumentRagResponseTranslator()
|
||||
response = DocumentRagResponse(
|
||||
response="Document content",
|
||||
end_of_stream=False,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert result["response"] == "Document content"
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
|
||||
class TestPromptResponseTranslator:
|
||||
"""Test PromptResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_text(self):
|
||||
"""Test that empty text strings are preserved"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
response = PromptResponse(
|
||||
text="",
|
||||
object=None,
|
||||
end_of_stream=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert "text" in result
|
||||
assert result["text"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_text(self):
|
||||
"""Test that non-empty text works correctly"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
response = PromptResponse(
|
||||
text="Some prompt response",
|
||||
object=None,
|
||||
end_of_stream=False,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert result["text"] == "Some prompt response"
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
def test_from_pulsar_with_none_text(self):
|
||||
"""Test that None text is handled correctly"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
response = PromptResponse(
|
||||
text=None,
|
||||
object='{"result": "data"}',
|
||||
end_of_stream=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert "text" not in result
|
||||
assert "object" in result
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_includes_end_of_stream(self):
|
||||
"""Test that end_of_stream flag is always included"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
||||
# Test with end_of_stream=False
|
||||
response = PromptResponse(
|
||||
text="chunk",
|
||||
object=None,
|
||||
end_of_stream=False,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert "end_of_stream" in result
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
|
||||
class TestTextCompletionResponseTranslator:
|
||||
"""Test TextCompletionResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_always_includes_response(self):
|
||||
"""Test that response field is always included, even if empty"""
|
||||
# Arrange
|
||||
translator = TextCompletionResponseTranslator()
|
||||
response = TextCompletionResponse(
|
||||
response="",
|
||||
end_of_stream=True,
|
||||
error=None,
|
||||
in_token=100,
|
||||
out_token=5,
|
||||
model="test-model"
|
||||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert - Response should always be present
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
|
||||
def test_from_response_with_completion_with_empty_final(self):
|
||||
"""Test that empty final response is handled correctly"""
|
||||
# Arrange
|
||||
translator = TextCompletionResponseTranslator()
|
||||
response = TextCompletionResponse(
|
||||
response="",
|
||||
end_of_stream=True,
|
||||
error=None,
|
||||
in_token=100,
|
||||
out_token=5,
|
||||
model="test-model"
|
||||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True
|
||||
assert result["response"] == ""
|
||||
|
||||
|
||||
class TestStreamingProtocolCompliance:
|
||||
"""Test that all translators follow streaming protocol conventions"""
|
||||
|
||||
@pytest.mark.parametrize("translator_class,response_class,field_name", [
|
||||
(GraphRagResponseTranslator, GraphRagResponse, "response"),
|
||||
(DocumentRagResponseTranslator, DocumentRagResponse, "response"),
|
||||
(PromptResponseTranslator, PromptResponse, "text"),
|
||||
(TextCompletionResponseTranslator, TextCompletionResponse, "response"),
|
||||
])
|
||||
def test_empty_final_chunk_preserved(self, translator_class, response_class, field_name):
|
||||
"""Test that all translators preserve empty final chunks"""
|
||||
# Arrange
|
||||
translator = translator_class()
|
||||
kwargs = {
|
||||
field_name: "",
|
||||
"end_of_stream": True,
|
||||
"error": None,
|
||||
}
|
||||
response = response_class(**kwargs)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
|
||||
assert result[field_name] == "", f"{translator_class.__name__} should preserve empty string"
|
||||
|
||||
@pytest.mark.parametrize("translator_class,response_class,field_name", [
|
||||
(GraphRagResponseTranslator, GraphRagResponse, "response"),
|
||||
(DocumentRagResponseTranslator, DocumentRagResponse, "response"),
|
||||
(TextCompletionResponseTranslator, TextCompletionResponse, "response"),
|
||||
])
|
||||
def test_end_of_stream_flag_included(self, translator_class, response_class, field_name):
|
||||
"""Test that end_of_stream flag is included in all response translators"""
|
||||
# Arrange
|
||||
translator = translator_class()
|
||||
kwargs = {
|
||||
field_name: "test content",
|
||||
"end_of_stream": True,
|
||||
"error": None,
|
||||
}
|
||||
response = response_class(**kwargs)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"
|
||||
assert result["end_of_stream"] is True
|
||||
Loading…
Add table
Add a link
Reference in a new issue