diff --git a/tests/integration/test_import_export_graceful_shutdown.py b/tests/integration/test_import_export_graceful_shutdown.py new file mode 100644 index 00000000..e80e4514 --- /dev/null +++ b/tests/integration/test_import_export_graceful_shutdown.py @@ -0,0 +1,454 @@ +"""Integration tests for import/export graceful shutdown functionality.""" + +import pytest +import asyncio +import json +import time +from unittest.mock import AsyncMock, MagicMock, patch +from aiohttp import web, WSMsgType, ClientWebSocketResponse +from trustgraph.gateway.dispatch.triples_import import TriplesImport +from trustgraph.gateway.dispatch.triples_export import TriplesExport +from trustgraph.gateway.running import Running +from trustgraph.base.publisher import Publisher +from trustgraph.base.subscriber import Subscriber + + +class MockPulsarMessage: + """Mock Pulsar message for testing.""" + + def __init__(self, data, message_id="test-id"): + self._data = data + self._message_id = message_id + self._properties = {"id": message_id} + + def value(self): + return self._data + + def properties(self): + return self._properties + + +class MockWebSocket: + """Mock WebSocket for testing.""" + + def __init__(self): + self.messages = [] + self.closed = False + self._close_called = False + + async def send_json(self, data): + if self.closed: + raise Exception("WebSocket is closed") + self.messages.append(data) + + async def close(self): + self._close_called = True + self.closed = True + + def json(self): + """Mock message json() method.""" + return { + "metadata": { + "id": "test-id", + "metadata": {}, + "user": "test-user", + "collection": "test-collection" + }, + "triples": [["subject", "predicate", "object"]] + } + + +@pytest.fixture +def mock_pulsar_client(): + """Mock Pulsar client for integration testing.""" + client = MagicMock() + + # Mock producer + producer = MagicMock() + producer.send = MagicMock() + producer.flush = MagicMock() + producer.close = MagicMock() + 
client.create_producer.return_value = producer + + # Mock consumer + consumer = MagicMock() + consumer.receive = AsyncMock() + consumer.acknowledge = MagicMock() + consumer.negative_acknowledge = MagicMock() + consumer.pause_message_listener = MagicMock() + consumer.unsubscribe = MagicMock() + consumer.close = MagicMock() + client.subscribe.return_value = consumer + + return client + + +@pytest.mark.asyncio +async def test_import_graceful_shutdown_integration(): + """Test import path handles shutdown gracefully with real message flow.""" + mock_client = MagicMock() + mock_producer = MagicMock() + mock_client.create_producer.return_value = mock_producer + + # Track sent messages + sent_messages = [] + def track_send(message, properties=None): + sent_messages.append((message, properties)) + + mock_producer.send.side_effect = track_send + + ws = MockWebSocket() + running = Running() + + # Create import handler + import_handler = TriplesImport( + ws=ws, + running=running, + pulsar_client=mock_client, + queue="test-triples-import" + ) + + await import_handler.start() + + # Send multiple messages rapidly + messages = [] + for i in range(10): + msg_data = { + "metadata": { + "id": f"msg-{i}", + "metadata": {}, + "user": "test-user", + "collection": "test-collection" + }, + "triples": [[f"subject-{i}", "predicate", f"object-{i}"]] + } + messages.append(msg_data) + + # Create mock message with json() method + mock_msg = MagicMock() + mock_msg.json.return_value = msg_data + + await import_handler.receive(mock_msg) + + # Allow brief processing time + await asyncio.sleep(0.1) + + # Shutdown while messages may be in flight + await import_handler.destroy() + + # Verify all messages reached producer + assert len(sent_messages) == 10 + + # Verify proper shutdown order was followed + mock_producer.flush.assert_called_once() + mock_producer.close.assert_called_once() + + # Verify messages have correct content + for i, (message, properties) in enumerate(sent_messages): + assert 
message.metadata.id == f"msg-{i}" + assert len(message.triples) == 1 + assert message.triples[0][0] == f"subject-{i}" + + +@pytest.mark.asyncio +async def test_export_no_message_loss_integration(): + """Test export path doesn't lose acknowledged messages.""" + mock_client = MagicMock() + mock_consumer = MagicMock() + mock_client.subscribe.return_value = mock_consumer + + # Create test messages + test_messages = [] + for i in range(20): + msg_data = { + "metadata": { + "id": f"export-msg-{i}", + "metadata": {}, + "user": "test-user", + "collection": "test-collection" + }, + "triples": [[f"export-subject-{i}", "predicate", f"export-object-{i}"]] + } + test_messages.append(MockPulsarMessage(msg_data, f"export-msg-{i}")) + + # Mock consumer to provide messages + message_iter = iter(test_messages) + async def mock_receive(): + try: + return next(message_iter) + except StopIteration: + # No more messages: time out explicitly. StopIteration must never escape a coroutine (Python turns it into RuntimeError) + await asyncio.sleep(1) + raise asyncio.TimeoutError("no more test messages") from None + + mock_consumer.receive = mock_receive + + ws = MockWebSocket() + running = Running() + + # Create export handler + export_handler = TriplesExport( + ws=ws, + running=running, + pulsar_client=mock_client, + queue="test-triples-export", + consumer="test-consumer", + subscriber="test-subscriber" + ) + + # Start export in background + export_task = asyncio.create_task(export_handler.run()) + + # Allow some messages to be processed + await asyncio.sleep(0.5) + + # Verify some messages were sent to websocket + initial_count = len(ws.messages) + assert initial_count > 0 + + # Force shutdown + await export_handler.destroy() + + # Wait for export task to complete + try: + await asyncio.wait_for(export_task, timeout=2.0) + except asyncio.TimeoutError: + export_task.cancel() + + # Verify websocket was closed + assert ws._close_called is True + + # Verify messages that were acknowledged were actually sent + final_count = len(ws.messages) + assert final_count >= initial_count + + # Verify no partial/corrupted messages + for 
msg in ws.messages: + assert "metadata" in msg + assert "triples" in msg + assert msg["metadata"]["id"].startswith("export-msg-") + + +@pytest.mark.asyncio +async def test_concurrent_import_export_shutdown(): + """Test concurrent import and export shutdown scenarios.""" + # Setup mock clients + import_client = MagicMock() + export_client = MagicMock() + + import_producer = MagicMock() + export_consumer = MagicMock() + + import_client.create_producer.return_value = import_producer + export_client.subscribe.return_value = export_consumer + + # Track operations + import_operations = [] + export_operations = [] + + def track_import_send(message, properties=None): + import_operations.append(("send", message.metadata.id)) + + def track_import_flush(): + import_operations.append(("flush",)) + + def track_export_ack(msg): + export_operations.append(("ack", msg.properties()["id"])) + + import_producer.send.side_effect = track_import_send + import_producer.flush.side_effect = track_import_flush + export_consumer.acknowledge.side_effect = track_export_ack + + # Create handlers + import_ws = MockWebSocket() + export_ws = MockWebSocket() + import_running = Running() + export_running = Running() + + import_handler = TriplesImport( + ws=import_ws, + running=import_running, + pulsar_client=import_client, + queue="concurrent-import" + ) + + export_handler = TriplesExport( + ws=export_ws, + running=export_running, + pulsar_client=export_client, + queue="concurrent-export", + consumer="concurrent-consumer", + subscriber="concurrent-subscriber" + ) + + # Start both handlers + await import_handler.start() + + # Send messages to import + for i in range(5): + msg = MagicMock() + msg.json.return_value = { + "metadata": { + "id": f"concurrent-{i}", + "metadata": {}, + "user": "test-user", + "collection": "test-collection" + }, + "triples": [[f"concurrent-subject-{i}", "predicate", "object"]] + } + await import_handler.receive(msg) + + # Shutdown both concurrently + import_shutdown = 
asyncio.create_task(import_handler.destroy()) + export_shutdown = asyncio.create_task(export_handler.destroy()) + + await asyncio.gather(import_shutdown, export_shutdown) + + # Verify import operations completed properly + assert len(import_operations) == 6 # 5 sends + 1 flush + assert ("flush",) in import_operations + + # Verify all import messages were processed + send_ops = [op for op in import_operations if op[0] == "send"] + assert len(send_ops) == 5 + + +@pytest.mark.asyncio +async def test_websocket_close_during_message_processing(): + """Test graceful handling when websocket closes during active message processing.""" + mock_client = MagicMock() + mock_producer = MagicMock() + mock_client.create_producer.return_value = mock_producer + + # Simulate slow message processing + processed_messages = [] + def slow_send(message, properties=None): # sync on purpose: a MagicMock never awaits its side_effect, so an async def body would not run + processed_messages.append(message.metadata.id) + time.sleep(0.1) # Simulate processing delay + + mock_producer.send.side_effect = slow_send + + ws = MockWebSocket() + running = Running() + + import_handler = TriplesImport( + ws=ws, + running=running, + pulsar_client=mock_client, + queue="slow-processing-import" + ) + + await import_handler.start() + + # Send many messages rapidly + message_tasks = [] + for i in range(10): + msg = MagicMock() + msg.json.return_value = { + "metadata": { + "id": f"slow-msg-{i}", + "metadata": {}, + "user": "test-user", + "collection": "test-collection" + }, + "triples": [[f"slow-subject-{i}", "predicate", "object"]] + } + task = asyncio.create_task(import_handler.receive(msg)) + message_tasks.append(task) + + # Allow some processing to start + await asyncio.sleep(0.05) + + # Close websocket while messages are being processed + ws.closed = True + + # Shutdown handler + await import_handler.destroy() + + # Wait for all message tasks to complete + await asyncio.gather(*message_tasks, return_exceptions=True) + + # Verify that messages that were being processed completed + # 
(graceful shutdown should allow in-flight processing to finish) + assert len(processed_messages) > 0 + + # Verify producer was properly flushed and closed + mock_producer.flush.assert_called_once() + mock_producer.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_backpressure_during_shutdown(): + """Test graceful shutdown under backpressure conditions.""" + mock_client = MagicMock() + mock_consumer = MagicMock() + mock_client.subscribe.return_value = mock_consumer + + # Create messages that will cause backpressure + large_messages = [] + for i in range(50): + msg_data = { + "metadata": { + "id": f"large-msg-{i}", + "metadata": {"large_field": "x" * 1000}, # Large metadata + "user": "test-user", + "collection": "test-collection" + }, + "triples": [[f"large-subject-{i}", "predicate", f"large-object-{i}"]] + } + large_messages.append(MockPulsarMessage(msg_data, f"large-msg-{i}")) + + # Mock slow websocket + class SlowWebSocket(MockWebSocket): + async def send_json(self, data): + await asyncio.sleep(0.02) # Slow send + await super().send_json(data) + + ws = SlowWebSocket() + running = Running() + + export_handler = TriplesExport( + ws=ws, + running=running, + pulsar_client=mock_client, + queue="backpressure-export", + consumer="backpressure-consumer", + subscriber="backpressure-subscriber" + ) + + # Mock consumer with backpressure (nothing reads the queue yet, so it must hold every preloaded message or put() would block forever) + message_queue = asyncio.Queue(maxsize=10) # Small queue, sized to the preload below + for msg in large_messages[:10]: # Only add first 10 + message_queue.put_nowait(msg) # raises QueueFull instead of hanging if the sizes ever drift apart + + async def mock_receive_with_backpressure(): + return await message_queue.get() + + mock_consumer.receive = mock_receive_with_backpressure + + # Start export task + export_task = asyncio.create_task(export_handler.run()) + + # Allow some processing + await asyncio.sleep(0.3) + + # Shutdown under backpressure + shutdown_start = time.time() + await export_handler.destroy() + shutdown_duration = time.time() - shutdown_start + + # Cancel export task + export_task.cancel() + try: + 
await export_task + except asyncio.CancelledError: + pass + + # Verify graceful shutdown completed within reasonable time + assert shutdown_duration < 10.0 # Should not hang indefinitely + + # Verify some messages were processed before shutdown + assert len(ws.messages) > 0 + + # Verify websocket was closed + assert ws._close_called is True \ No newline at end of file diff --git a/tests/unit/test_base/test_publisher_graceful_shutdown.py b/tests/unit/test_base/test_publisher_graceful_shutdown.py new file mode 100644 index 00000000..e15cb1ec --- /dev/null +++ b/tests/unit/test_base/test_publisher_graceful_shutdown.py @@ -0,0 +1,330 @@ +"""Unit tests for Publisher graceful shutdown functionality.""" + +import pytest +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock, patch +from trustgraph.base.publisher import Publisher + + +@pytest.fixture +def mock_pulsar_client(): + """Mock Pulsar client for testing.""" + client = MagicMock() + producer = AsyncMock() + producer.send = MagicMock() + producer.flush = MagicMock() + producer.close = MagicMock() + client.create_producer.return_value = producer + return client + + +@pytest.fixture +def publisher(mock_pulsar_client): + """Create Publisher instance for testing.""" + return Publisher( + client=mock_pulsar_client, + topic="test-topic", + schema=dict, + max_size=10, + drain_timeout=2.0 + ) + + +@pytest.mark.asyncio +async def test_publisher_queue_drain(): + """Verify Publisher drains queue on shutdown.""" + mock_client = MagicMock() + mock_producer = MagicMock() + mock_client.create_producer.return_value = mock_producer + + publisher = Publisher( + client=mock_client, + topic="test-topic", + schema=dict, + max_size=10, + drain_timeout=1.0 # Shorter timeout for testing + ) + + # Don't start the actual run loop - just test the drain logic + # Fill queue with messages directly + for i in range(5): + await publisher.q.put((f"id-{i}", {"data": i})) + + # Verify queue has messages + assert not 
publisher.q.empty() + + # Mock the producer creation in run() method by patching + with patch.object(publisher, 'run') as mock_run: + # Create a realistic run implementation that processes the queue + async def mock_run_impl(): + # Simulate the actual run logic for drain + producer = mock_producer + while not publisher.q.empty(): + try: + id, item = await asyncio.wait_for(publisher.q.get(), timeout=0.1) + producer.send(item, {"id": id}) + except asyncio.TimeoutError: + break + producer.flush() + producer.close() + + mock_run.side_effect = mock_run_impl + + # Start and stop publisher + await publisher.start() + await publisher.stop() + + # Verify all messages were sent + assert publisher.q.empty() + assert mock_producer.send.call_count == 5 + mock_producer.flush.assert_called_once() + mock_producer.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_publisher_rejects_messages_during_drain(): + """Verify Publisher rejects new messages during shutdown.""" + mock_client = MagicMock() + mock_producer = MagicMock() + mock_client.create_producer.return_value = mock_producer + + publisher = Publisher( + client=mock_client, + topic="test-topic", + schema=dict, + max_size=10, + drain_timeout=1.0 + ) + + # Don't start the actual run loop + # Add one message directly + await publisher.q.put(("id-1", {"data": 1})) + + # Start shutdown process manually + publisher.running = False + publisher.draining = True + + # Try to send message during drain + with pytest.raises(RuntimeError, match="Publisher is shutting down"): + await publisher.send("id-2", {"data": 2}) + + +@pytest.mark.asyncio +async def test_publisher_drain_timeout(): + """Verify Publisher respects drain timeout.""" + mock_client = MagicMock() + mock_producer = MagicMock() + mock_client.create_producer.return_value = mock_producer + + publisher = Publisher( + client=mock_client, + topic="test-topic", + schema=dict, + max_size=10, + drain_timeout=0.2 # Short timeout for testing + ) + + # Fill queue with 
many messages directly + for i in range(10): + await publisher.q.put((f"id-{i}", {"data": i})) + + # Mock slow message processing + def slow_send(*args, **kwargs): + time.sleep(0.1) # Simulate slow send + + mock_producer.send.side_effect = slow_send + + with patch.object(publisher, 'run') as mock_run: + # Create a run implementation that respects timeout + async def mock_run_with_timeout(): + producer = mock_producer + end_time = time.time() + publisher.drain_timeout + + while not publisher.q.empty() and time.time() < end_time: + try: + id, item = await asyncio.wait_for(publisher.q.get(), timeout=0.05) + producer.send(item, {"id": id}) + except asyncio.TimeoutError: + break + + producer.flush() + producer.close() + + mock_run.side_effect = mock_run_with_timeout + + start_time = time.time() + await publisher.start() + await publisher.stop() + end_time = time.time() + + # Should timeout quickly + assert end_time - start_time < 1.0 + + # Should have called flush and close even with timeout + mock_producer.flush.assert_called_once() + mock_producer.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_publisher_successful_drain(): + """Verify Publisher drains successfully under normal conditions.""" + mock_client = MagicMock() + mock_producer = MagicMock() + mock_client.create_producer.return_value = mock_producer + + publisher = Publisher( + client=mock_client, + topic="test-topic", + schema=dict, + max_size=10, + drain_timeout=2.0 + ) + + # Add messages directly to queue + messages = [] + for i in range(3): + msg = {"data": i} + await publisher.q.put((f"id-{i}", msg)) + messages.append(msg) + + with patch.object(publisher, 'run') as mock_run: + # Create a successful drain implementation + async def mock_successful_drain(): + producer = mock_producer + processed = [] + + while not publisher.q.empty(): + id, item = await publisher.q.get() + producer.send(item, {"id": id}) + processed.append((id, item)) + + producer.flush() + producer.close() + return 
processed + + mock_run.side_effect = mock_successful_drain + + await publisher.start() + await publisher.stop() + + # All messages should be sent + assert publisher.q.empty() + assert mock_producer.send.call_count == 3 + + # Verify correct messages were sent + sent_calls = mock_producer.send.call_args_list + for i, call in enumerate(sent_calls): + args, kwargs = call + assert args[0] == {"data": i} # message content + # Note: kwargs format depends on how send was called in mock + # Just verify message was sent with correct content + + +@pytest.mark.asyncio +async def test_publisher_state_transitions(): + """Test Publisher state transitions during graceful shutdown.""" + mock_client = MagicMock() + mock_producer = MagicMock() + mock_client.create_producer.return_value = mock_producer + + publisher = Publisher( + client=mock_client, + topic="test-topic", + schema=dict, + max_size=10, + drain_timeout=1.0 + ) + + # Initial state + assert publisher.running is True + assert publisher.draining is False + + # Add message directly + await publisher.q.put(("id-1", {"data": 1})) + + with patch.object(publisher, 'run') as mock_run: + # Mock run that simulates state transitions + async def mock_run_with_states(): + # Simulate drain process + publisher.running = False + publisher.draining = True + + # Process messages + while not publisher.q.empty(): + id, item = await publisher.q.get() + mock_producer.send(item, {"id": id}) + + # Complete drain + publisher.draining = False + mock_producer.flush() + mock_producer.close() + + mock_run.side_effect = mock_run_with_states + + await publisher.start() + await publisher.stop() + + # Should have completed all state transitions + assert publisher.running is False + assert publisher.draining is False + mock_producer.send.assert_called_once() + mock_producer.flush.assert_called_once() + mock_producer.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_publisher_exception_handling(): + """Test Publisher handles exceptions 
during drain gracefully.""" + mock_client = MagicMock() + mock_producer = MagicMock() + mock_client.create_producer.return_value = mock_producer + + # Mock producer.send to raise exception on second call + call_count = 0 + def failing_send(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + raise Exception("Send failed") + + mock_producer.send.side_effect = failing_send + + publisher = Publisher( + client=mock_client, + topic="test-topic", + schema=dict, + max_size=10, + drain_timeout=1.0 + ) + + # Add messages directly + await publisher.q.put(("id-1", {"data": 1})) + await publisher.q.put(("id-2", {"data": 2})) + + with patch.object(publisher, 'run') as mock_run: + # Mock run that handles exceptions gracefully + async def mock_run_with_exceptions(): + producer = mock_producer + + while not publisher.q.empty(): + try: + id, item = await publisher.q.get() + producer.send(item, {"id": id}) + except Exception as e: + # Log exception but continue processing + continue + + # Always call flush and close + producer.flush() + producer.close() + + mock_run.side_effect = mock_run_with_exceptions + + await publisher.start() + await publisher.stop() + + # Should have attempted to send both messages + assert mock_producer.send.call_count == 2 + mock_producer.flush.assert_called_once() + mock_producer.close.assert_called_once() \ No newline at end of file diff --git a/tests/unit/test_base/test_subscriber_graceful_shutdown.py b/tests/unit/test_base/test_subscriber_graceful_shutdown.py new file mode 100644 index 00000000..6f690d4a --- /dev/null +++ b/tests/unit/test_base/test_subscriber_graceful_shutdown.py @@ -0,0 +1,315 @@ +"""Unit tests for Subscriber graceful shutdown functionality.""" + +import pytest +import asyncio +import uuid +from unittest.mock import AsyncMock, MagicMock, patch +from trustgraph.base.subscriber import Subscriber + + +@pytest.fixture +def mock_pulsar_client(): + """Mock Pulsar client for testing.""" + client = MagicMock() + 
consumer = MagicMock() + consumer.receive = MagicMock() + consumer.acknowledge = MagicMock() + consumer.negative_acknowledge = MagicMock() + consumer.pause_message_listener = MagicMock() + consumer.unsubscribe = MagicMock() + consumer.close = MagicMock() + client.subscribe.return_value = consumer + return client + + +@pytest.fixture +def subscriber(mock_pulsar_client): + """Create Subscriber instance for testing.""" + return Subscriber( + client=mock_pulsar_client, + topic="test-topic", + subscription="test-subscription", + consumer_name="test-consumer", + schema=dict, + max_size=10, + drain_timeout=2.0, + backpressure_strategy="block" + ) + + +def create_mock_message(message_id="test-id", data=None): + """Create a mock Pulsar message.""" + msg = MagicMock() + msg.properties.return_value = {"id": message_id} + msg.value.return_value = data or {"test": "data"} + return msg + + +@pytest.mark.asyncio +async def test_subscriber_deferred_acknowledgment_success(): + """Verify Subscriber only acks on successful delivery.""" + mock_client = MagicMock() + mock_consumer = MagicMock() + mock_client.subscribe.return_value = mock_consumer + + subscriber = Subscriber( + client=mock_client, + topic="test-topic", + subscription="test-subscription", + consumer_name="test-consumer", + schema=dict, + max_size=10, + backpressure_strategy="block" + ) + + # Create queue for subscription + queue = await subscriber.subscribe("test-queue") + + # Create mock message + msg = create_mock_message("msg-1", {"data": "test"}) + + # Process message + await subscriber._process_message(msg) + + # Should acknowledge successful delivery + mock_consumer.acknowledge.assert_called_once_with(msg) + mock_consumer.negative_acknowledge.assert_not_called() + + # Message should be in queue + assert not queue.empty() + received_msg = await queue.get() + assert received_msg == {"data": "test"} + + +@pytest.mark.asyncio +async def test_subscriber_deferred_acknowledgment_failure(): + """Verify Subscriber negative 
acks on delivery failure.""" + mock_client = MagicMock() + mock_consumer = MagicMock() + mock_client.subscribe.return_value = mock_consumer + + subscriber = Subscriber( + client=mock_client, + topic="test-topic", + subscription="test-subscription", + consumer_name="test-consumer", + schema=dict, + max_size=1, # Very small queue + backpressure_strategy="drop_new" + ) + + # Create queue and fill it + queue = await subscriber.subscribe("test-queue") + await queue.put({"existing": "data"}) + + # Create mock message - should be dropped + msg = create_mock_message("msg-1", {"data": "test"}) + + # Process message (should fail due to full queue + drop_new strategy) + await subscriber._process_message(msg) + + # Should negative acknowledge failed delivery + mock_consumer.negative_acknowledge.assert_called_once_with(msg) + mock_consumer.acknowledge.assert_not_called() + + +@pytest.mark.asyncio +async def test_subscriber_backpressure_strategies(): + """Test different backpressure strategies.""" + mock_client = MagicMock() + mock_consumer = MagicMock() + mock_client.subscribe.return_value = mock_consumer + + # Test drop_oldest strategy + subscriber = Subscriber( + client=mock_client, + topic="test-topic", + subscription="test-subscription", + consumer_name="test-consumer", + schema=dict, + max_size=2, + backpressure_strategy="drop_oldest" + ) + + queue = await subscriber.subscribe("test-queue") + + # Fill queue + await queue.put({"data": "old1"}) + await queue.put({"data": "old2"}) + + # Add new message (should drop oldest) + msg = create_mock_message("msg-1", {"data": "new"}) + await subscriber._process_message(msg) + + # Should acknowledge delivery + mock_consumer.acknowledge.assert_called_once_with(msg) + + # Queue should have new message (old one dropped) + messages = [] + while not queue.empty(): + messages.append(await queue.get()) + + # Should contain old2 and new (old1 was dropped) + assert len(messages) == 2 + assert {"data": "new"} in messages + + 
+@pytest.mark.asyncio +async def test_subscriber_graceful_shutdown(): + """Test Subscriber graceful shutdown with queue draining.""" + mock_client = MagicMock() + mock_consumer = MagicMock() + mock_client.subscribe.return_value = mock_consumer + + subscriber = Subscriber( + client=mock_client, + topic="test-topic", + subscription="test-subscription", + consumer_name="test-consumer", + schema=dict, + max_size=10, + drain_timeout=1.0 + ) + + await subscriber.start() + + # Create subscription with messages + queue = await subscriber.subscribe("test-queue") + await queue.put({"data": "msg1"}) + await queue.put({"data": "msg2"}) + + # Initial state + assert subscriber.running is True + assert subscriber.draining is False + + # Start shutdown + stop_task = asyncio.create_task(subscriber.stop()) + + # Allow brief processing + await asyncio.sleep(0.1) + + # Should be in drain state + assert subscriber.running is False + assert subscriber.draining is True + + # Should pause message listener + mock_consumer.pause_message_listener.assert_called_once() + + # Complete shutdown + await stop_task + + # Should have cleaned up + mock_consumer.unsubscribe.assert_called_once() + mock_consumer.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_subscriber_drain_timeout(): + """Test Subscriber respects drain timeout.""" + mock_client = MagicMock() + mock_consumer = MagicMock() + mock_client.subscribe.return_value = mock_consumer + + subscriber = Subscriber( + client=mock_client, + topic="test-topic", + subscription="test-subscription", + consumer_name="test-consumer", + schema=dict, + max_size=10, + drain_timeout=0.1 # Very short timeout + ) + + await subscriber.start() + + # Create subscription with many messages + queue = await subscriber.subscribe("test-queue") + for i in range(20): + await queue.put({"data": f"msg{i}"}) + + import time + start_time = time.time() + await subscriber.stop() + end_time = time.time() + + # Should timeout quickly + assert end_time - 
start_time < 1.0 + + # Queue should still have messages (drain timed out) + assert not queue.empty() + + +@pytest.mark.asyncio +async def test_subscriber_pending_acks_cleanup(): + """Test Subscriber cleans up pending acknowledgments on shutdown.""" + mock_client = MagicMock() + mock_consumer = MagicMock() + mock_client.subscribe.return_value = mock_consumer + + subscriber = Subscriber( + client=mock_client, + topic="test-topic", + subscription="test-subscription", + consumer_name="test-consumer", + schema=dict, + max_size=10 + ) + + await subscriber.start() + + # Add pending acknowledgments manually (simulating in-flight messages) + msg1 = create_mock_message("msg-1") + msg2 = create_mock_message("msg-2") + subscriber.pending_acks["ack-1"] = msg1 + subscriber.pending_acks["ack-2"] = msg2 + + # Stop subscriber + await subscriber.stop() + + # Should negative acknowledge pending messages + assert mock_consumer.negative_acknowledge.call_count == 2 + mock_consumer.negative_acknowledge.assert_any_call(msg1) + mock_consumer.negative_acknowledge.assert_any_call(msg2) + + # Pending acks should be cleared + assert len(subscriber.pending_acks) == 0 + + +@pytest.mark.asyncio +async def test_subscriber_multiple_subscribers(): + """Test Subscriber with multiple concurrent subscribers.""" + mock_client = MagicMock() + mock_consumer = MagicMock() + mock_client.subscribe.return_value = mock_consumer + + subscriber = Subscriber( + client=mock_client, + topic="test-topic", + subscription="test-subscription", + consumer_name="test-consumer", + schema=dict, + max_size=10 + ) + + # Create multiple subscriptions + queue1 = await subscriber.subscribe("queue-1") + queue2 = await subscriber.subscribe("queue-2") + queue_all = await subscriber.subscribe_all("queue-all") + + # Process message + msg = create_mock_message("msg-1", {"data": "broadcast"}) + await subscriber._process_message(msg) + + # Should acknowledge (successful delivery to all queues) + 
mock_consumer.acknowledge.assert_called_once_with(msg) + + # Message should be in specific queue (queue-1) and broadcast queue + assert not queue1.empty() + assert queue2.empty() # No message for queue-2 + assert not queue_all.empty() + + # Verify message content + msg1 = await queue1.get() + msg_all = await queue_all.get() + assert msg1 == {"data": "broadcast"} + assert msg_all == {"data": "broadcast"} \ No newline at end of file diff --git a/tests/unit/test_gateway/test_socket_graceful_shutdown.py b/tests/unit/test_gateway/test_socket_graceful_shutdown.py new file mode 100644 index 00000000..9202e615 --- /dev/null +++ b/tests/unit/test_gateway/test_socket_graceful_shutdown.py @@ -0,0 +1,324 @@ +"""Unit tests for SocketEndpoint graceful shutdown functionality.""" + +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +from aiohttp import web, WSMsgType +from trustgraph.gateway.endpoint.socket import SocketEndpoint +from trustgraph.gateway.running import Running + + +@pytest.fixture +def mock_auth(): + """Mock authentication service.""" + auth = MagicMock() + auth.permitted.return_value = True + return auth + + +@pytest.fixture +def mock_dispatcher_factory(): + """Mock dispatcher factory function.""" + async def dispatcher_factory(ws, running, match_info): + dispatcher = AsyncMock() + dispatcher.run = AsyncMock() + dispatcher.receive = AsyncMock() + dispatcher.destroy = AsyncMock() + return dispatcher + + return dispatcher_factory + + +@pytest.fixture +def socket_endpoint(mock_auth, mock_dispatcher_factory): + """Create SocketEndpoint for testing.""" + return SocketEndpoint( + endpoint_path="/test-socket", + auth=mock_auth, + dispatcher=mock_dispatcher_factory + ) + + +@pytest.fixture +def mock_websocket(): + """Mock websocket response.""" + ws = AsyncMock(spec=web.WebSocketResponse) + ws.prepare = AsyncMock() + ws.close = AsyncMock() + ws.closed = False + return ws + + +@pytest.fixture +def mock_request(): + """Mock HTTP 
request.""" + request = MagicMock() + request.query = {"token": "test-token"} + request.match_info = {} + return request + + +@pytest.mark.asyncio +async def test_listener_graceful_shutdown_on_close(): + """Test listener handles websocket close gracefully.""" + socket_endpoint = SocketEndpoint("/test", MagicMock(), AsyncMock()) + + # Mock websocket that closes after one message + ws = AsyncMock() + + # Create async iterator that yields one message then closes + async def mock_iterator(self): # a plain function installed as a mock magic method is called with the mock as first argument + # Yield normal message + msg = MagicMock() + msg.type = WSMsgType.TEXT + yield msg + + # Yield close message + close_msg = MagicMock() + close_msg.type = WSMsgType.CLOSE + yield close_msg + + ws.__aiter__ = mock_iterator + + dispatcher = AsyncMock() + running = Running() + + with patch('asyncio.sleep') as mock_sleep: + await socket_endpoint.listener(ws, dispatcher, running) + + # Should have processed one message + dispatcher.receive.assert_called_once() + + # Should have initiated graceful shutdown + assert running.get() is False + + # Should have slept for grace period + mock_sleep.assert_called_once_with(1.0) + + +@pytest.mark.asyncio +async def test_handle_normal_flow(): + """Test normal websocket handling flow.""" + mock_auth = MagicMock() + mock_auth.permitted.return_value = True + + dispatcher_created = False + async def mock_dispatcher_factory(ws, running, match_info): + nonlocal dispatcher_created + dispatcher_created = True + dispatcher = AsyncMock() + dispatcher.destroy = AsyncMock() + return dispatcher + + socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) + + request = MagicMock() + request.query = {"token": "valid-token"} + request.match_info = {} + + with patch('aiohttp.web.WebSocketResponse') as mock_ws_class: + mock_ws = AsyncMock() + mock_ws.prepare = AsyncMock() + mock_ws.close = AsyncMock() + mock_ws.closed = False + mock_ws_class.return_value = mock_ws + + with patch('asyncio.TaskGroup') as mock_task_group: + # Mock task group 
context manager + mock_tg = AsyncMock() + mock_tg.__aenter__ = AsyncMock(return_value=mock_tg) + mock_tg.__aexit__ = AsyncMock(return_value=None) + mock_tg.create_task = MagicMock(return_value=AsyncMock()) + mock_task_group.return_value = mock_tg + + result = await socket_endpoint.handle(request) + + # Should have created dispatcher + assert dispatcher_created is True + + # Should return websocket + assert result == mock_ws + + +@pytest.mark.asyncio +async def test_handle_exception_group_cleanup(): + """Test exception group triggers dispatcher cleanup.""" + mock_auth = MagicMock() + mock_auth.permitted.return_value = True + + mock_dispatcher = AsyncMock() + mock_dispatcher.destroy = AsyncMock() + + async def mock_dispatcher_factory(ws, running, match_info): + return mock_dispatcher + + socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) + + request = MagicMock() + request.query = {"token": "valid-token"} + request.match_info = {} + + # Mock TaskGroup to raise ExceptionGroup + class TestException(Exception): + pass + + exception_group = ExceptionGroup("Test exceptions", [TestException("test")]) + + with patch('aiohttp.web.WebSocketResponse') as mock_ws_class: + mock_ws = AsyncMock() + mock_ws.prepare = AsyncMock() + mock_ws.close = AsyncMock() + mock_ws.closed = False + mock_ws_class.return_value = mock_ws + + with patch('asyncio.TaskGroup') as mock_task_group: + mock_tg = AsyncMock() + mock_tg.__aenter__ = AsyncMock(side_effect=exception_group) + mock_tg.__aexit__ = AsyncMock(return_value=None) + mock_task_group.return_value = mock_tg + + with patch('asyncio.wait_for') as mock_wait_for: + mock_wait_for.return_value = None + + result = await socket_endpoint.handle(request) + + # Should have attempted graceful cleanup + mock_wait_for.assert_called_once() + + # Should have called destroy in finally block + assert mock_dispatcher.destroy.call_count >= 1 + + # Should have closed websocket + mock_ws.close.assert_called() + + 
@pytest.mark.asyncio
async def test_handle_dispatcher_cleanup_timeout():
    """Test dispatcher cleanup with timeout."""
    mock_auth = MagicMock()
    mock_auth.permitted.return_value = True

    # The dispatcher's destroy() is an ordinary AsyncMock; the timeout is
    # simulated below by making the patched asyncio.wait_for raise.
    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    # Mock TaskGroup to raise exception
    exception_group = ExceptionGroup("Test", [Exception("test")])

    with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
        mock_ws = AsyncMock()
        mock_ws.prepare = AsyncMock()
        mock_ws.close = AsyncMock()
        mock_ws.closed = False
        mock_ws_class.return_value = mock_ws

        with patch('asyncio.TaskGroup') as mock_task_group:
            mock_tg = AsyncMock()
            mock_tg.__aenter__ = AsyncMock(side_effect=exception_group)
            mock_tg.__aexit__ = AsyncMock(return_value=None)
            mock_task_group.return_value = mock_tg

            # Mock asyncio.wait_for to raise TimeoutError
            with patch('asyncio.wait_for') as mock_wait_for:
                mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")

                await socket_endpoint.handle(request)

    # Should have attempted cleanup with timeout. Inspect call_args rather
    # than comparing against a fresh ``mock_dispatcher.destroy()``: calling
    # it here would build a new coroutine object that never equals the one
    # passed by the code under test, and would itself bump
    # ``destroy.call_count``, defeating the check below.
    mock_wait_for.assert_called_once()
    assert len(mock_wait_for.call_args.args) == 1
    assert mock_wait_for.call_args.kwargs.get("timeout") == 5.0

    # Should still call destroy in finally block
    assert mock_dispatcher.destroy.call_count >= 1


@pytest.mark.asyncio
async def test_handle_unauthorized_request():
    """Test handling of unauthorized requests."""
    mock_auth = MagicMock()
    mock_auth.permitted.return_value = False  # Unauthorized

    socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())

    request = MagicMock()
    request.query = {"token": "invalid-token"}

    result = await socket_endpoint.handle(request)

    # Should return HTTP 401
    assert isinstance(result, web.HTTPUnauthorized)

    # Should have checked permission
    mock_auth.permitted.assert_called_once_with("invalid-token", "socket")


@pytest.mark.asyncio
async def test_handle_missing_token():
    """Test handling of requests with missing token."""
    mock_auth = MagicMock()
    mock_auth.permitted.return_value = False

    socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())

    request = MagicMock()
    request.query = {}  # No token

    result = await socket_endpoint.handle(request)

    # Should return HTTP 401
    assert isinstance(result, web.HTTPUnauthorized)

    # Should have checked permission with empty token
    mock_auth.permitted.assert_called_once_with("", "socket")


@pytest.mark.asyncio
async def test_handle_websocket_already_closed():
    """Test handling when websocket is already closed."""
    mock_auth = MagicMock()
    mock_auth.permitted.return_value = True

    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
        mock_ws = AsyncMock()
        mock_ws.prepare = AsyncMock()
        mock_ws.close = AsyncMock()
        mock_ws.closed = True  # Already closed
        mock_ws_class.return_value = mock_ws

        with patch('asyncio.TaskGroup') as mock_task_group:
            mock_tg = AsyncMock()
            mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
            mock_tg.__aexit__ = AsyncMock(return_value=None)
            mock_tg.create_task = MagicMock(return_value=AsyncMock())
            mock_task_group.return_value = mock_tg

            await socket_endpoint.handle(request)

    # Should still have called destroy
    mock_dispatcher.destroy.assert_called()

    # Should not attempt to close already closed websocket
    mock_ws.close.assert_not_called()  # Not called in finally since ws.closed = True