trustgraph/tests/unit/test_base/test_publisher_graceful_shutdown.py

330 lines
10 KiB
Python
Raw Normal View History

"""Unit tests for Publisher graceful shutdown functionality."""
import pytest
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.base.publisher import Publisher
@pytest.fixture
def mock_pulsar_backend():
"""Mock Pulsar backend for testing."""
backend = MagicMock()
producer = AsyncMock()
producer.send = MagicMock()
producer.flush = MagicMock()
producer.close = MagicMock()
backend.create_producer.return_value = producer
return backend
@pytest.fixture
def publisher(mock_pulsar_backend):
"""Create Publisher instance for testing."""
return Publisher(
backend=mock_pulsar_backend,
topic="test-topic",
schema=dict,
max_size=10,
drain_timeout=2.0
)
@pytest.mark.asyncio
async def test_publisher_queue_drain():
"""Verify Publisher drains queue on shutdown."""
mock_backend = MagicMock()
mock_producer = MagicMock()
mock_backend.create_producer.return_value = mock_producer
publisher = Publisher(
backend=mock_backend,
topic="test-topic",
schema=dict,
max_size=10,
drain_timeout=1.0 # Shorter timeout for testing
)
# Don't start the actual run loop - just test the drain logic
# Fill queue with messages directly
for i in range(5):
await publisher.q.put((f"id-{i}", {"data": i}))
# Verify queue has messages
assert not publisher.q.empty()
# Mock the producer creation in run() method by patching
with patch.object(publisher, 'run') as mock_run:
# Create a realistic run implementation that processes the queue
async def mock_run_impl():
# Simulate the actual run logic for drain
producer = mock_producer
while not publisher.q.empty():
try:
id, item = await asyncio.wait_for(publisher.q.get(), timeout=0.1)
producer.send(item, {"id": id})
except asyncio.TimeoutError:
break
producer.flush()
producer.close()
mock_run.side_effect = mock_run_impl
# Start and stop publisher
await publisher.start()
await publisher.stop()
# Verify all messages were sent
assert publisher.q.empty()
assert mock_producer.send.call_count == 5
mock_producer.flush.assert_called_once()
mock_producer.close.assert_called_once()
@pytest.mark.asyncio
async def test_publisher_rejects_messages_during_drain():
"""Verify Publisher rejects new messages during shutdown."""
mock_backend = MagicMock()
mock_producer = MagicMock()
mock_backend.create_producer.return_value = mock_producer
publisher = Publisher(
backend=mock_backend,
topic="test-topic",
schema=dict,
max_size=10,
drain_timeout=1.0
)
# Don't start the actual run loop
# Add one message directly
await publisher.q.put(("id-1", {"data": 1}))
# Start shutdown process manually
publisher.running = False
publisher.draining = True
# Try to send message during drain
with pytest.raises(RuntimeError, match="Publisher is shutting down"):
await publisher.send("id-2", {"data": 2})
@pytest.mark.asyncio
async def test_publisher_drain_timeout():
"""Verify Publisher respects drain timeout."""
mock_backend = MagicMock()
mock_producer = MagicMock()
mock_backend.create_producer.return_value = mock_producer
publisher = Publisher(
backend=mock_backend,
topic="test-topic",
schema=dict,
max_size=10,
drain_timeout=0.2 # Short timeout for testing
)
# Fill queue with many messages directly
for i in range(10):
await publisher.q.put((f"id-{i}", {"data": i}))
# Mock slow message processing
def slow_send(*args, **kwargs):
time.sleep(0.1) # Simulate slow send
mock_producer.send.side_effect = slow_send
with patch.object(publisher, 'run') as mock_run:
# Create a run implementation that respects timeout
async def mock_run_with_timeout():
producer = mock_producer
end_time = time.time() + publisher.drain_timeout
while not publisher.q.empty() and time.time() < end_time:
try:
id, item = await asyncio.wait_for(publisher.q.get(), timeout=0.05)
producer.send(item, {"id": id})
except asyncio.TimeoutError:
break
producer.flush()
producer.close()
mock_run.side_effect = mock_run_with_timeout
start_time = time.time()
await publisher.start()
await publisher.stop()
end_time = time.time()
# Should timeout quickly
assert end_time - start_time < 1.0
# Should have called flush and close even with timeout
mock_producer.flush.assert_called_once()
mock_producer.close.assert_called_once()
@pytest.mark.asyncio
async def test_publisher_successful_drain():
"""Verify Publisher drains successfully under normal conditions."""
mock_backend = MagicMock()
mock_producer = MagicMock()
mock_backend.create_producer.return_value = mock_producer
publisher = Publisher(
backend=mock_backend,
topic="test-topic",
schema=dict,
max_size=10,
drain_timeout=2.0
)
# Add messages directly to queue
messages = []
for i in range(3):
msg = {"data": i}
await publisher.q.put((f"id-{i}", msg))
messages.append(msg)
with patch.object(publisher, 'run') as mock_run:
# Create a successful drain implementation
async def mock_successful_drain():
producer = mock_producer
processed = []
while not publisher.q.empty():
id, item = await publisher.q.get()
producer.send(item, {"id": id})
processed.append((id, item))
producer.flush()
producer.close()
return processed
mock_run.side_effect = mock_successful_drain
await publisher.start()
await publisher.stop()
# All messages should be sent
assert publisher.q.empty()
assert mock_producer.send.call_count == 3
# Verify correct messages were sent
sent_calls = mock_producer.send.call_args_list
for i, call in enumerate(sent_calls):
args, kwargs = call
assert args[0] == {"data": i} # message content
# Note: kwargs format depends on how send was called in mock
# Just verify message was sent with correct content
@pytest.mark.asyncio
async def test_publisher_state_transitions():
"""Test Publisher state transitions during graceful shutdown."""
mock_backend = MagicMock()
mock_producer = MagicMock()
mock_backend.create_producer.return_value = mock_producer
publisher = Publisher(
backend=mock_backend,
topic="test-topic",
schema=dict,
max_size=10,
drain_timeout=1.0
)
# Initial state
assert publisher.running is True
assert publisher.draining is False
# Add message directly
await publisher.q.put(("id-1", {"data": 1}))
with patch.object(publisher, 'run') as mock_run:
# Mock run that simulates state transitions
async def mock_run_with_states():
# Simulate drain process
publisher.running = False
publisher.draining = True
# Process messages
while not publisher.q.empty():
id, item = await publisher.q.get()
mock_producer.send(item, {"id": id})
# Complete drain
publisher.draining = False
mock_producer.flush()
mock_producer.close()
mock_run.side_effect = mock_run_with_states
await publisher.start()
await publisher.stop()
# Should have completed all state transitions
assert publisher.running is False
assert publisher.draining is False
mock_producer.send.assert_called_once()
mock_producer.flush.assert_called_once()
mock_producer.close.assert_called_once()
@pytest.mark.asyncio
async def test_publisher_exception_handling():
"""Test Publisher handles exceptions during drain gracefully."""
mock_backend = MagicMock()
mock_producer = MagicMock()
mock_backend.create_producer.return_value = mock_producer
# Mock producer.send to raise exception on second call
call_count = 0
def failing_send(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 2:
raise Exception("Send failed")
mock_producer.send.side_effect = failing_send
publisher = Publisher(
backend=mock_backend,
topic="test-topic",
schema=dict,
max_size=10,
drain_timeout=1.0
)
# Add messages directly
await publisher.q.put(("id-1", {"data": 1}))
await publisher.q.put(("id-2", {"data": 2}))
with patch.object(publisher, 'run') as mock_run:
# Mock run that handles exceptions gracefully
async def mock_run_with_exceptions():
producer = mock_producer
while not publisher.q.empty():
try:
id, item = await publisher.q.get()
producer.send(item, {"id": id})
except Exception as e:
# Log exception but continue processing
continue
# Always call flush and close
producer.flush()
producer.close()
mock_run.side_effect = mock_run_with_exceptions
await publisher.start()
await publisher.stop()
# Should have attempted to send both messages
assert mock_producer.send.call_count == 2
mock_producer.flush.assert_called_once()
mock_producer.close.assert_called_once()