diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index af5dda5b..7e18f0de 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -480,11 +480,15 @@ def streaming_chunk_collector(): class ChunkCollector: def __init__(self): self.chunks = [] + self.end_of_stream_flags = [] self.complete = False - async def collect(self, chunk): - """Async callback to collect chunks""" + async def collect(self, chunk, end_of_stream=False): + """Async callback to collect chunks with end_of_stream flag""" self.chunks.append(chunk) + self.end_of_stream_flags.append(end_of_stream) + if end_of_stream: + self.complete = True def get_full_text(self): """Concatenate all chunk content""" @@ -496,6 +500,14 @@ def streaming_chunk_collector(): return [c.get("chunk_type") for c in self.chunks] return [] + def verify_streaming_protocol(self): + """Verify that streaming protocol is correct""" + assert len(self.chunks) > 0, "Should have received at least one chunk" + assert len(self.chunks) == len(self.end_of_stream_flags), "Each chunk should have an end_of_stream flag" + assert self.end_of_stream_flags.count(True) == 1, "Exactly one chunk should have end_of_stream=True" + assert self.end_of_stream_flags[-1] is True, "Last chunk should have end_of_stream=True" + assert self.complete is True, "Should be marked complete after final chunk" + return ChunkCollector diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index 4b792443..84061add 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -46,9 +46,16 @@ class TestDocumentRagStreaming: full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data." 
if streaming and chunk_callback: - # Simulate streaming chunks + # Simulate streaming chunks with end_of_stream flags + chunks = [] async for chunk in mock_streaming_llm_response(): - await chunk_callback(chunk) + chunks.append(chunk) + + # Send all chunks with end_of_stream=False except the last + for i, chunk in enumerate(chunks): + is_final = (i == len(chunks) - 1) + await chunk_callback(chunk, is_final) + return full_text else: # Non-streaming response - same text @@ -89,6 +96,9 @@ class TestDocumentRagStreaming: assert_streaming_chunks_valid(collector.chunks, min_chunks=1) assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1) + # Verify streaming protocol compliance + collector.verify_streaming_protocol() + # Verify full response matches concatenated chunks full_from_chunks = collector.get_full_text() assert result == full_from_chunks @@ -117,7 +127,7 @@ class TestDocumentRagStreaming: # Act - Streaming streaming_chunks = [] - async def collect(chunk): + async def collect(chunk, end_of_stream): streaming_chunks.append(chunk) streaming_result = await document_rag_streaming.query( diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index 92da6527..47dd84b6 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -59,9 +59,16 @@ class TestGraphRagStreaming: full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data." 
if streaming and chunk_callback: - # Simulate streaming chunks + # Simulate streaming chunks with end_of_stream flags + chunks = [] async for chunk in mock_streaming_llm_response(): - await chunk_callback(chunk) + chunks.append(chunk) + + # Send all chunks with end_of_stream=False except the last + for i, chunk in enumerate(chunks): + is_final = (i == len(chunks) - 1) + await chunk_callback(chunk, is_final) + return full_text else: # Non-streaming response - same text @@ -102,6 +109,9 @@ class TestGraphRagStreaming: assert_streaming_chunks_valid(collector.chunks, min_chunks=1) assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1) + # Verify streaming protocol compliance + collector.verify_streaming_protocol() + # Verify full response matches concatenated chunks full_from_chunks = collector.get_full_text() assert result == full_from_chunks @@ -128,7 +138,7 @@ class TestGraphRagStreaming: # Act - Streaming streaming_chunks = [] - async def collect(chunk): + async def collect(chunk, end_of_stream): streaming_chunks.append(chunk) streaming_result = await graph_rag_streaming.query( diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py new file mode 100644 index 00000000..d2ceea95 --- /dev/null +++ b/tests/integration/test_rag_streaming_protocol.py @@ -0,0 +1,351 @@ +""" +Integration tests for RAG service streaming protocol compliance. + +These tests verify that RAG services correctly forward end_of_stream flags +and don't duplicate final chunks, ensuring proper streaming semantics. 
+""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, call +from trustgraph.retrieval.graph_rag.graph_rag import GraphRag +from trustgraph.retrieval.document_rag.document_rag import DocumentRag + + +class TestGraphRagStreamingProtocol: + """Integration tests for GraphRAG streaming protocol""" + + @pytest.fixture + def mock_embeddings_client(self): + """Mock embeddings client""" + client = AsyncMock() + client.embed.return_value = [[0.1, 0.2, 0.3]] + return client + + @pytest.fixture + def mock_graph_embeddings_client(self): + """Mock graph embeddings client""" + client = AsyncMock() + client.query.return_value = ["entity1", "entity2"] + return client + + @pytest.fixture + def mock_triples_client(self): + """Mock triples client""" + client = AsyncMock() + client.query.return_value = [] + return client + + @pytest.fixture + def mock_streaming_prompt_client(self): + """Mock prompt client that simulates realistic streaming with end_of_stream flags""" + client = AsyncMock() + + async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None): + if streaming and chunk_callback: + # Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True + await chunk_callback("The", False) + await chunk_callback(" answer", False) + await chunk_callback(" is here.", False) + await chunk_callback("", True) # Empty final chunk with end_of_stream=True + return "" # Return value not used since callback handles everything + else: + return "The answer is here." 
+ + client.kg_prompt.side_effect = kg_prompt_side_effect + return client + + @pytest.fixture + def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client, + mock_triples_client, mock_streaming_prompt_client): + """Create GraphRag instance with mocked dependencies""" + return GraphRag( + embeddings_client=mock_embeddings_client, + graph_embeddings_client=mock_graph_embeddings_client, + triples_client=mock_triples_client, + prompt_client=mock_streaming_prompt_client, + verbose=False + ) + + @pytest.mark.asyncio + async def test_callback_receives_end_of_stream_parameter(self, graph_rag): + """Test that callback receives end_of_stream parameter""" + # Arrange + callback = AsyncMock() + + # Act + await graph_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=callback + ) + + # Assert - callback should receive (chunk, end_of_stream) signature + assert callback.call_count == 4 + # All calls should have 2 arguments + for call_args in callback.call_args_list: + assert len(call_args.args) == 2, "Callback should receive (chunk, end_of_stream)" + + @pytest.mark.asyncio + async def test_end_of_stream_flag_forwarded_correctly(self, graph_rag): + """Test that end_of_stream flags are forwarded correctly""" + # Arrange + chunks_with_flags = [] + + async def collect(chunk, end_of_stream): + chunks_with_flags.append((chunk, end_of_stream)) + + # Act + await graph_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert + assert len(chunks_with_flags) == 4 + + # First three chunks should have end_of_stream=False + assert chunks_with_flags[0] == ("The", False) + assert chunks_with_flags[1] == (" answer", False) + assert chunks_with_flags[2] == (" is here.", False) + + # Final chunk should have end_of_stream=True + assert chunks_with_flags[3] == ("", True) + + @pytest.mark.asyncio + async def 
test_no_duplicate_final_chunk(self, graph_rag): + """Test that final chunk is not duplicated""" + # Arrange + chunks = [] + + async def collect(chunk, end_of_stream): + chunks.append(chunk) + + # Act + await graph_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert - should have exactly 4 chunks, no duplicates + assert len(chunks) == 4 + assert chunks == ["The", " answer", " is here.", ""] + + # The last chunk appears exactly once + assert chunks.count("") == 1 + + @pytest.mark.asyncio + async def test_exactly_one_end_of_stream_true(self, graph_rag): + """Test that exactly one message has end_of_stream=True""" + # Arrange + end_of_stream_flags = [] + + async def collect(chunk, end_of_stream): + end_of_stream_flags.append(end_of_stream) + + # Act + await graph_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert - exactly one True + assert end_of_stream_flags.count(True) == 1 + assert end_of_stream_flags.count(False) == 3 + + @pytest.mark.asyncio + async def test_empty_final_chunk_preserved(self, graph_rag): + """Test that empty final chunks are preserved and forwarded""" + # Arrange + final_chunk = None + final_flag = None + + async def collect(chunk, end_of_stream): + nonlocal final_chunk, final_flag + if end_of_stream: + final_chunk = chunk + final_flag = end_of_stream + + # Act + await graph_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert + assert final_flag is True + assert final_chunk == "", "Empty final chunk should be preserved" + + +class TestDocumentRagStreamingProtocol: + """Integration tests for DocumentRAG streaming protocol""" + + @pytest.fixture + def mock_embeddings_client(self): + """Mock embeddings client""" + client = AsyncMock() + client.embed.return_value = [[0.1, 0.2, 
0.3]] + return client + + @pytest.fixture + def mock_doc_embeddings_client(self): + """Mock document embeddings client""" + client = AsyncMock() + client.query.return_value = ["doc1", "doc2"] + return client + + @pytest.fixture + def mock_streaming_prompt_client(self): + """Mock prompt client with streaming support""" + client = AsyncMock() + + async def document_prompt_side_effect(query, documents, timeout=600, streaming=False, chunk_callback=None): + if streaming and chunk_callback: + # Simulate streaming with non-empty final chunk (some LLMs do this) + await chunk_callback("Document", False) + await chunk_callback(" summary", False) + await chunk_callback(".", True) # Non-empty final chunk + return "" + else: + return "Document summary." + + client.document_prompt.side_effect = document_prompt_side_effect + return client + + @pytest.fixture + def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, + mock_streaming_prompt_client): + """Create DocumentRag instance with mocked dependencies""" + return DocumentRag( + embeddings_client=mock_embeddings_client, + doc_embeddings_client=mock_doc_embeddings_client, + prompt_client=mock_streaming_prompt_client, + verbose=False + ) + + @pytest.mark.asyncio + async def test_callback_receives_end_of_stream_parameter(self, document_rag): + """Test that callback receives end_of_stream parameter""" + # Arrange + callback = AsyncMock() + + # Act + await document_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=callback + ) + + # Assert + assert callback.call_count == 3 + for call_args in callback.call_args_list: + assert len(call_args.args) == 2 + + @pytest.mark.asyncio + async def test_non_empty_final_chunk_preserved(self, document_rag): + """Test that non-empty final chunks are preserved with correct flag""" + # Arrange + chunks_with_flags = [] + + async def collect(chunk, end_of_stream): + chunks_with_flags.append((chunk, end_of_stream)) 
+ + # Act + await document_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert + assert len(chunks_with_flags) == 3 + assert chunks_with_flags[0] == ("Document", False) + assert chunks_with_flags[1] == (" summary", False) + assert chunks_with_flags[2] == (".", True) # Non-empty final chunk + + @pytest.mark.asyncio + async def test_no_duplicate_final_chunk(self, document_rag): + """Test that final chunk is not duplicated""" + # Arrange + chunks = [] + + async def collect(chunk, end_of_stream): + chunks.append(chunk) + + # Act + await document_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert - final "." appears exactly once + assert chunks.count(".") == 1 + assert chunks == ["Document", " summary", "."] + + +class TestStreamingProtocolEdgeCases: + """Test edge cases in streaming protocol""" + + @pytest.mark.asyncio + async def test_multiple_empty_chunks_before_final(self): + """Test handling of multiple empty chunks (edge case)""" + # Arrange + client = AsyncMock() + + async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None): + if streaming and chunk_callback: + await chunk_callback("text", False) + await chunk_callback("", False) # Empty but not final + await chunk_callback("more", False) + await chunk_callback("", True) # Empty and final + return "" + else: + return "textmore" + + client.kg_prompt.side_effect = kg_prompt_with_empties + + rag = GraphRag( + embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])), + graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])), + triples_client=AsyncMock(query=AsyncMock(return_value=[])), + prompt_client=client, + verbose=False + ) + + chunks_with_flags = [] + + async def collect(chunk, end_of_stream): + chunks_with_flags.append((chunk, end_of_stream)) + + # Act + await 
rag.query( + query="test", + streaming=True, + chunk_callback=collect + ) + + # Assert + assert len(chunks_with_flags) == 4 + assert chunks_with_flags[-1] == ("", True) # Final empty chunk + end_of_stream_flags = [f for c, f in chunks_with_flags] + assert end_of_stream_flags.count(True) == 1 diff --git a/tests/unit/test_base/test_prompt_client_streaming.py b/tests/unit/test_base/test_prompt_client_streaming.py new file mode 100644 index 00000000..83a4b90e --- /dev/null +++ b/tests/unit/test_base/test_prompt_client_streaming.py @@ -0,0 +1,260 @@ +""" +Unit tests for PromptClient streaming callback behavior. + +These tests verify that the prompt client correctly passes the end_of_stream +flag to chunk callbacks, ensuring proper streaming protocol compliance. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, call, patch +from trustgraph.base.prompt_client import PromptClient +from trustgraph.schema import PromptResponse + + +class TestPromptClientStreamingCallback: + """Test PromptClient streaming callback behavior""" + + @pytest.fixture + def prompt_client(self): + """Create a PromptClient with mocked dependencies""" + # Mock all the required initialization parameters + with patch.object(PromptClient, '__init__', lambda self: None): + client = PromptClient() + return client + + @pytest.fixture + def mock_request_response(self): + """Create a mock request/response handler""" + async def mock_request(request, recipient=None, timeout=600): + if recipient: + # Simulate streaming responses + responses = [ + PromptResponse(text="Hello", object=None, error=None, end_of_stream=False), + PromptResponse(text=" world", object=None, error=None, end_of_stream=False), + PromptResponse(text="!", object=None, error=None, end_of_stream=False), + PromptResponse(text="", object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + else: + # Non-streaming response + return 
PromptResponse(text="Hello world!", object=None, error=None) + + return mock_request + + @pytest.mark.asyncio + async def test_callback_receives_chunk_and_end_of_stream(self, prompt_client, mock_request_response): + """Test that callback receives both chunk text and end_of_stream flag""" + # Arrange + prompt_client.request = mock_request_response + + callback = AsyncMock() + + # Act + await prompt_client.prompt( + id="test-prompt", + variables={"query": "test"}, + streaming=True, + chunk_callback=callback + ) + + # Assert - callback should be called with (chunk, end_of_stream) signature + assert callback.call_count == 4 + + # Verify first chunk: text + end_of_stream=False + assert callback.call_args_list[0] == call("Hello", False) + + # Verify second chunk + assert callback.call_args_list[1] == call(" world", False) + + # Verify third chunk + assert callback.call_args_list[2] == call("!", False) + + # Verify final chunk: empty text + end_of_stream=True + assert callback.call_args_list[3] == call("", True) + + @pytest.mark.asyncio + async def test_callback_receives_empty_final_chunk(self, prompt_client, mock_request_response): + """Test that empty final chunks are passed to callback""" + # Arrange + prompt_client.request = mock_request_response + + chunks_received = [] + + async def collect_chunks(chunk, end_of_stream): + chunks_received.append((chunk, end_of_stream)) + + # Act + await prompt_client.prompt( + id="test-prompt", + variables={"query": "test"}, + streaming=True, + chunk_callback=collect_chunks + ) + + # Assert - should receive the empty final chunk + final_chunk = chunks_received[-1] + assert final_chunk == ("", True), "Final chunk should be empty string with end_of_stream=True" + + @pytest.mark.asyncio + async def test_callback_signature_with_non_empty_final_chunk(self, prompt_client): + """Test callback signature when LLM sends non-empty final chunk""" + # Arrange + async def mock_request_non_empty_final(request, recipient=None, timeout=600): + if 
recipient: + # Some LLMs send content in the final chunk + responses = [ + PromptResponse(text="Hello", object=None, error=None, end_of_stream=False), + PromptResponse(text=" world!", object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + + prompt_client.request = mock_request_non_empty_final + + callback = AsyncMock() + + # Act + await prompt_client.prompt( + id="test-prompt", + variables={"query": "test"}, + streaming=True, + chunk_callback=callback + ) + + # Assert + assert callback.call_count == 2 + assert callback.call_args_list[0] == call("Hello", False) + assert callback.call_args_list[1] == call(" world!", True) + + @pytest.mark.asyncio + async def test_callback_not_called_without_text(self, prompt_client): + """Test that callback is not called for responses without text""" + # Arrange + async def mock_request_no_text(request, recipient=None, timeout=600): + if recipient: + # Response with only end_of_stream, no text + responses = [ + PromptResponse(text="Content", object=None, error=None, end_of_stream=False), + PromptResponse(text=None, object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + + prompt_client.request = mock_request_no_text + + callback = AsyncMock() + + # Act + await prompt_client.prompt( + id="test-prompt", + variables={"query": "test"}, + streaming=True, + chunk_callback=callback + ) + + # Assert - callback should only be called once (for "Content") + assert callback.call_count == 1 + assert callback.call_args_list[0] == call("Content", False) + + @pytest.mark.asyncio + async def test_synchronous_callback_also_receives_end_of_stream(self, prompt_client): + """Test that synchronous callbacks also receive end_of_stream parameter""" + # Arrange + async def mock_request(request, recipient=None, timeout=600): + if recipient: + responses = [ + PromptResponse(text="test", 
object=None, error=None, end_of_stream=False), + PromptResponse(text="", object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + + prompt_client.request = mock_request + + callback = MagicMock() # Synchronous mock + + # Act + await prompt_client.prompt( + id="test-prompt", + variables={"query": "test"}, + streaming=True, + chunk_callback=callback + ) + + # Assert - synchronous callback should also get both parameters + assert callback.call_count == 2 + assert callback.call_args_list[0] == call("test", False) + assert callback.call_args_list[1] == call("", True) + + @pytest.mark.asyncio + async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client): + """Test that kg_prompt correctly passes streaming parameters""" + # Arrange + async def mock_request(request, recipient=None, timeout=600): + if recipient: + responses = [ + PromptResponse(text="Answer", object=None, error=None, end_of_stream=False), + PromptResponse(text="", object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + + prompt_client.request = mock_request + + callback = AsyncMock() + + # Act + await prompt_client.kg_prompt( + query="What is machine learning?", + kg=[("subject", "predicate", "object")], + streaming=True, + chunk_callback=callback + ) + + # Assert + assert callback.call_count == 2 + assert callback.call_args_list[0] == call("Answer", False) + assert callback.call_args_list[1] == call("", True) + + @pytest.mark.asyncio + async def test_document_prompt_passes_parameters_to_callback(self, prompt_client): + """Test that document_prompt correctly passes streaming parameters""" + # Arrange + async def mock_request(request, recipient=None, timeout=600): + if recipient: + responses = [ + PromptResponse(text="Summary", object=None, error=None, end_of_stream=False), + PromptResponse(text="", object=None, error=None, 
end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + + prompt_client.request = mock_request + + callback = AsyncMock() + + # Act + await prompt_client.document_prompt( + query="Summarize this", + documents=["doc1", "doc2"], + streaming=True, + chunk_callback=callback + ) + + # Assert + assert callback.call_count == 2 + assert callback.call_args_list[0] == call("Summary", False) + assert callback.call_args_list[1] == call("", True) diff --git a/tests/unit/test_gateway/test_socket_graceful_shutdown.py b/tests/unit/test_gateway/test_socket_graceful_shutdown.py index 4e8768a1..50195272 100644 --- a/tests/unit/test_gateway/test_socket_graceful_shutdown.py +++ b/tests/unit/test_gateway/test_socket_graceful_shutdown.py @@ -102,7 +102,7 @@ async def test_handle_normal_flow(): """Test normal websocket handling flow.""" mock_auth = MagicMock() mock_auth.permitted.return_value = True - + dispatcher_created = False async def mock_dispatcher_factory(ws, running, match_info): nonlocal dispatcher_created @@ -110,33 +110,41 @@ async def test_handle_normal_flow(): dispatcher = AsyncMock() dispatcher.destroy = AsyncMock() return dispatcher - + socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) - + request = MagicMock() request.query = {"token": "valid-token"} request.match_info = {} - + with patch('aiohttp.web.WebSocketResponse') as mock_ws_class: mock_ws = AsyncMock() mock_ws.prepare = AsyncMock() mock_ws.close = AsyncMock() mock_ws.closed = False mock_ws_class.return_value = mock_ws - + with patch('asyncio.TaskGroup') as mock_task_group: # Mock task group context manager mock_tg = AsyncMock() mock_tg.__aenter__ = AsyncMock(return_value=mock_tg) mock_tg.__aexit__ = AsyncMock(return_value=None) - mock_tg.create_task = MagicMock(return_value=AsyncMock()) + + # Create proper mock tasks that look like asyncio.Task objects + def create_task_mock(coro): + task = AsyncMock() + task.done = 
MagicMock(return_value=True) + task.cancelled = MagicMock(return_value=False) + return task + + mock_tg.create_task = MagicMock(side_effect=create_task_mock) mock_task_group.return_value = mock_tg - + result = await socket_endpoint.handle(request) - + # Should have created dispatcher assert dispatcher_created is True - + # Should return websocket assert result == mock_ws @@ -146,50 +154,58 @@ async def test_handle_exception_group_cleanup(): """Test exception group triggers dispatcher cleanup.""" mock_auth = MagicMock() mock_auth.permitted.return_value = True - + mock_dispatcher = AsyncMock() mock_dispatcher.destroy = AsyncMock() - + async def mock_dispatcher_factory(ws, running, match_info): return mock_dispatcher - + socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) - + request = MagicMock() request.query = {"token": "valid-token"} request.match_info = {} - + # Mock TaskGroup to raise ExceptionGroup class TestException(Exception): pass - + exception_group = ExceptionGroup("Test exceptions", [TestException("test")]) - + with patch('aiohttp.web.WebSocketResponse') as mock_ws_class: mock_ws = AsyncMock() mock_ws.prepare = AsyncMock() - mock_ws.close = AsyncMock() + mock_ws.close = AsyncMock() mock_ws.closed = False mock_ws_class.return_value = mock_ws - + with patch('asyncio.TaskGroup') as mock_task_group: mock_tg = AsyncMock() mock_tg.__aenter__ = AsyncMock(return_value=mock_tg) mock_tg.__aexit__ = AsyncMock(side_effect=exception_group) - mock_tg.create_task = MagicMock(side_effect=TestException("test")) + + # Create proper mock tasks that look like asyncio.Task objects + def create_task_mock(coro): + task = AsyncMock() + task.done = MagicMock(return_value=True) + task.cancelled = MagicMock(return_value=False) + return task + + mock_tg.create_task = MagicMock(side_effect=create_task_mock) mock_task_group.return_value = mock_tg - + with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for: 
mock_wait_for.return_value = None - + result = await socket_endpoint.handle(request) - + # Should have attempted graceful cleanup mock_wait_for.assert_called_once() - + # Should have called destroy in finally block assert mock_dispatcher.destroy.call_count >= 1 - + # Should have closed websocket mock_ws.close.assert_called() @@ -199,48 +215,56 @@ async def test_handle_dispatcher_cleanup_timeout(): """Test dispatcher cleanup with timeout.""" mock_auth = MagicMock() mock_auth.permitted.return_value = True - + # Mock dispatcher that takes long to destroy mock_dispatcher = AsyncMock() mock_dispatcher.destroy = AsyncMock() - + async def mock_dispatcher_factory(ws, running, match_info): return mock_dispatcher - + socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) - + request = MagicMock() request.query = {"token": "valid-token"} request.match_info = {} - + # Mock TaskGroup to raise exception exception_group = ExceptionGroup("Test", [Exception("test")]) - + with patch('aiohttp.web.WebSocketResponse') as mock_ws_class: mock_ws = AsyncMock() mock_ws.prepare = AsyncMock() mock_ws.close = AsyncMock() mock_ws.closed = False mock_ws_class.return_value = mock_ws - + with patch('asyncio.TaskGroup') as mock_task_group: mock_tg = AsyncMock() mock_tg.__aenter__ = AsyncMock(return_value=mock_tg) mock_tg.__aexit__ = AsyncMock(side_effect=exception_group) - mock_tg.create_task = MagicMock(side_effect=Exception("test")) + + # Create proper mock tasks that look like asyncio.Task objects + def create_task_mock(coro): + task = AsyncMock() + task.done = MagicMock(return_value=True) + task.cancelled = MagicMock(return_value=False) + return task + + mock_tg.create_task = MagicMock(side_effect=create_task_mock) mock_task_group.return_value = mock_tg - + # Mock asyncio.wait_for to raise TimeoutError with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for: mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout") - + result = await 
socket_endpoint.handle(request) - + # Should have attempted cleanup with timeout mock_wait_for.assert_called_once() # Check that timeout was passed correctly assert mock_wait_for.call_args[1]['timeout'] == 5.0 - + # Should still call destroy in finally block assert mock_dispatcher.destroy.call_count >= 1 @@ -290,37 +314,45 @@ async def test_handle_websocket_already_closed(): """Test handling when websocket is already closed.""" mock_auth = MagicMock() mock_auth.permitted.return_value = True - + mock_dispatcher = AsyncMock() mock_dispatcher.destroy = AsyncMock() - + async def mock_dispatcher_factory(ws, running, match_info): return mock_dispatcher - + socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) - + request = MagicMock() request.query = {"token": "valid-token"} request.match_info = {} - + with patch('aiohttp.web.WebSocketResponse') as mock_ws_class: mock_ws = AsyncMock() mock_ws.prepare = AsyncMock() mock_ws.close = AsyncMock() mock_ws.closed = True # Already closed mock_ws_class.return_value = mock_ws - + with patch('asyncio.TaskGroup') as mock_task_group: mock_tg = AsyncMock() mock_tg.__aenter__ = AsyncMock(return_value=mock_tg) mock_tg.__aexit__ = AsyncMock(return_value=None) - mock_tg.create_task = MagicMock(return_value=AsyncMock()) + + # Create proper mock tasks that look like asyncio.Task objects + def create_task_mock(coro): + task = AsyncMock() + task.done = MagicMock(return_value=True) + task.cancelled = MagicMock(return_value=False) + return task + + mock_tg.create_task = MagicMock(side_effect=create_task_mock) mock_task_group.return_value = mock_tg - + result = await socket_endpoint.handle(request) - + # Should still have called destroy mock_dispatcher.destroy.assert_called() - + # Should not attempt to close already closed websocket mock_ws.close.assert_not_called() # Not called in finally since ws.closed = True \ No newline at end of file diff --git a/tests/unit/test_gateway/test_streaming_translators.py 
b/tests/unit/test_gateway/test_streaming_translators.py new file mode 100644 index 00000000..e767edd4 --- /dev/null +++ b/tests/unit/test_gateway/test_streaming_translators.py @@ -0,0 +1,326 @@ +""" +Unit tests for streaming behavior in message translators. + +These tests verify that translators correctly handle empty strings and +end_of_stream flags in streaming responses, preventing bugs where empty +final chunks could be dropped due to falsy value checks. +""" + +import pytest +from unittest.mock import MagicMock +from trustgraph.messaging.translators.retrieval import ( + GraphRagResponseTranslator, + DocumentRagResponseTranslator, +) +from trustgraph.messaging.translators.prompt import PromptResponseTranslator +from trustgraph.messaging.translators.text_completion import TextCompletionResponseTranslator +from trustgraph.schema import ( + GraphRagResponse, + DocumentRagResponse, + PromptResponse, + TextCompletionResponse, +) + + +class TestGraphRagResponseTranslator: + """Test GraphRagResponseTranslator streaming behavior""" + + def test_from_pulsar_with_empty_response(self): + """Test that empty response strings are preserved""" + # Arrange + translator = GraphRagResponseTranslator() + response = GraphRagResponse( + response="", + end_of_stream=True, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert - Empty string should be included in result + assert "response" in result + assert result["response"] == "" + assert result["end_of_stream"] is True + + def test_from_pulsar_with_non_empty_response(self): + """Test that non-empty responses work correctly""" + # Arrange + translator = GraphRagResponseTranslator() + response = GraphRagResponse( + response="Some text", + end_of_stream=False, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert result["response"] == "Some text" + assert result["end_of_stream"] is False + + def test_from_pulsar_with_none_response(self): + """Test that None response 
is handled correctly""" + # Arrange + translator = GraphRagResponseTranslator() + response = GraphRagResponse( + response=None, + end_of_stream=True, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert - None should not be included + assert "response" not in result + assert result["end_of_stream"] is True + + def test_from_response_with_completion_returns_correct_flag(self): + """Test that from_response_with_completion returns correct is_final flag""" + # Arrange + translator = GraphRagResponseTranslator() + + # Test non-final chunk + response_chunk = GraphRagResponse( + response="chunk", + end_of_stream=False, + error=None + ) + + # Act + result, is_final = translator.from_response_with_completion(response_chunk) + + # Assert + assert is_final is False + assert result["end_of_stream"] is False + + # Test final chunk with empty content + final_response = GraphRagResponse( + response="", + end_of_stream=True, + error=None + ) + + # Act + result, is_final = translator.from_response_with_completion(final_response) + + # Assert + assert is_final is True + assert result["response"] == "" + assert result["end_of_stream"] is True + + +class TestDocumentRagResponseTranslator: + """Test DocumentRagResponseTranslator streaming behavior""" + + def test_from_pulsar_with_empty_response(self): + """Test that empty response strings are preserved""" + # Arrange + translator = DocumentRagResponseTranslator() + response = DocumentRagResponse( + response="", + end_of_stream=True, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert "response" in result + assert result["response"] == "" + assert result["end_of_stream"] is True + + def test_from_pulsar_with_non_empty_response(self): + """Test that non-empty responses work correctly""" + # Arrange + translator = DocumentRagResponseTranslator() + response = DocumentRagResponse( + response="Document content", + end_of_stream=False, + error=None + ) + + # Act + result 
= translator.from_pulsar(response) + + # Assert + assert result["response"] == "Document content" + assert result["end_of_stream"] is False + + +class TestPromptResponseTranslator: + """Test PromptResponseTranslator streaming behavior""" + + def test_from_pulsar_with_empty_text(self): + """Test that empty text strings are preserved""" + # Arrange + translator = PromptResponseTranslator() + response = PromptResponse( + text="", + object=None, + end_of_stream=True, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert "text" in result + assert result["text"] == "" + assert result["end_of_stream"] is True + + def test_from_pulsar_with_non_empty_text(self): + """Test that non-empty text works correctly""" + # Arrange + translator = PromptResponseTranslator() + response = PromptResponse( + text="Some prompt response", + object=None, + end_of_stream=False, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert result["text"] == "Some prompt response" + assert result["end_of_stream"] is False + + def test_from_pulsar_with_none_text(self): + """Test that None text is handled correctly""" + # Arrange + translator = PromptResponseTranslator() + response = PromptResponse( + text=None, + object='{"result": "data"}', + end_of_stream=True, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert "text" not in result + assert "object" in result + assert result["end_of_stream"] is True + + def test_from_pulsar_includes_end_of_stream(self): + """Test that end_of_stream flag is always included""" + # Arrange + translator = PromptResponseTranslator() + + # Test with end_of_stream=False + response = PromptResponse( + text="chunk", + object=None, + end_of_stream=False, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert "end_of_stream" in result + assert result["end_of_stream"] is False + + +class TestTextCompletionResponseTranslator: 
+ """Test TextCompletionResponseTranslator streaming behavior""" + + def test_from_pulsar_always_includes_response(self): + """Test that response field is always included, even if empty""" + # Arrange + translator = TextCompletionResponseTranslator() + response = TextCompletionResponse( + response="", + end_of_stream=True, + error=None, + in_token=100, + out_token=5, + model="test-model" + ) + + # Act + result = translator.from_pulsar(response) + + # Assert - Response should always be present + assert "response" in result + assert result["response"] == "" + + def test_from_response_with_completion_with_empty_final(self): + """Test that empty final response is handled correctly""" + # Arrange + translator = TextCompletionResponseTranslator() + response = TextCompletionResponse( + response="", + end_of_stream=True, + error=None, + in_token=100, + out_token=5, + model="test-model" + ) + + # Act + result, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is True + assert result["response"] == "" + + +class TestStreamingProtocolCompliance: + """Test that all translators follow streaming protocol conventions""" + + @pytest.mark.parametrize("translator_class,response_class,field_name", [ + (GraphRagResponseTranslator, GraphRagResponse, "response"), + (DocumentRagResponseTranslator, DocumentRagResponse, "response"), + (PromptResponseTranslator, PromptResponse, "text"), + (TextCompletionResponseTranslator, TextCompletionResponse, "response"), + ]) + def test_empty_final_chunk_preserved(self, translator_class, response_class, field_name): + """Test that all translators preserve empty final chunks""" + # Arrange + translator = translator_class() + kwargs = { + field_name: "", + "end_of_stream": True, + "error": None, + } + response = response_class(**kwargs) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when 
empty" + assert result[field_name] == "", f"{translator_class.__name__} should preserve empty string" + + @pytest.mark.parametrize("translator_class,response_class,field_name", [ + (GraphRagResponseTranslator, GraphRagResponse, "response"), + (DocumentRagResponseTranslator, DocumentRagResponse, "response"), + (TextCompletionResponseTranslator, TextCompletionResponse, "response"), + ]) + def test_end_of_stream_flag_included(self, translator_class, response_class, field_name): + """Test that end_of_stream flag is included in all response translators""" + # Arrange + translator = translator_class() + kwargs = { + field_name: "test content", + "end_of_stream": True, + "error": None, + } + response = response_class(**kwargs) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag" + assert result["end_of_stream"] is True diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py index 370cf78a..74b25132 100644 --- a/trustgraph-base/trustgraph/base/prompt_client.py +++ b/trustgraph-base/trustgraph/base/prompt_client.py @@ -51,7 +51,8 @@ class PromptClient(RequestResponse): end_stream = getattr(resp, 'end_of_stream', False) - if resp.text: + # Check for None explicitly so an empty-string chunk (e.g. an empty final message) is not dropped; a None text is still skipped + if resp.text is not None: last_text = resp.text # Call chunk callback if provided with both chunk and end_of_stream flag if chunk_callback: diff --git a/trustgraph-base/trustgraph/messaging/translators/text_completion.py b/trustgraph-base/trustgraph/messaging/translators/text_completion.py index b4ba4d13..fa3749b5 100644 --- a/trustgraph-base/trustgraph/messaging/translators/text_completion.py +++ b/trustgraph-base/trustgraph/messaging/translators/text_completion.py @@ -28,14 +28,17 @@ class TextCompletionResponseTranslator(MessageTranslator): def from_pulsar(self, obj: TextCompletionResponse) -> 
Dict[str, Any]: result = {"response": obj.response} - + if obj.in_token: result["in_token"] = obj.in_token if obj.out_token: result["out_token"] = obj.out_token if obj.model: result["model"] = obj.model - + + # Always include end_of_stream flag for streaming support + result["end_of_stream"] = getattr(obj, "end_of_stream", False) + return result def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]: