# trustgraph/tests/unit/test_base/test_publisher_graceful_shutdown.py
# (Python, 330 lines, 10 KiB; last change 2025-09-20 16:00:37 +01:00)
"""Unit tests for Publisher graceful shutdown functionality."""
import pytest
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.base.publisher import Publisher
@pytest.fixture
def mock_pulsar_client():
    """Provide a mocked Pulsar client for Publisher tests.

    ``create_producer()`` on the returned client always yields the same
    producer mock. That producer is an ``AsyncMock`` whose ``send``,
    ``flush`` and ``close`` attributes are replaced by plain, synchronous
    ``MagicMock`` objects.
    """
    fake_producer = AsyncMock()
    for method_name in ("send", "flush", "close"):
        setattr(fake_producer, method_name, MagicMock())

    fake_client = MagicMock()
    fake_client.create_producer.return_value = fake_producer
    return fake_client
@pytest.fixture
def publisher(mock_pulsar_client):
    """Provide a Publisher bound to the mocked Pulsar client.

    The instance targets ``test-topic`` with a ``dict`` schema, a queue
    bound of 10 and a 2 second drain timeout.
    """
    options = {
        "topic": "test-topic",
        "schema": dict,
        "max_size": 10,
        "drain_timeout": 2.0,
    }
    return Publisher(client=mock_pulsar_client, **options)
@pytest.mark.asyncio
async def test_publisher_queue_drain():
    """Verify Publisher drains queue on shutdown.

    Five messages are queued before start(). After stop() the queue must
    be empty, every message must have been handed to the producer in the
    order it was queued, and the producer must be flushed and closed once.

    NOTE(review): ``run`` is patched below, so the drain loop exercised
    here is the test's own simulation rather than ``Publisher.run``; what
    is really checked is that start()/stop() drive ``run`` to completion.
    """
    mock_client = MagicMock()
    mock_producer = MagicMock()
    mock_client.create_producer.return_value = mock_producer

    publisher = Publisher(
        client=mock_client,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=1.0,  # Shorter timeout for testing
    )

    # Don't start the real run loop - fill the queue directly.
    expected_ids = [f"id-{i}" for i in range(5)]
    for i, msg_id in enumerate(expected_ids):
        await publisher.q.put((msg_id, {"data": i}))
    assert not publisher.q.empty()

    async def mock_run_impl():
        # Simulate the drain part of run(): send until the queue is empty.
        producer = mock_producer
        while not publisher.q.empty():
            try:
                msg_id, item = await asyncio.wait_for(
                    publisher.q.get(), timeout=0.1
                )
            except asyncio.TimeoutError:
                break
            producer.send(item, {"id": msg_id})
        producer.flush()
        producer.close()

    with patch.object(publisher, "run") as mock_run:
        mock_run.side_effect = mock_run_impl
        await publisher.start()
        await publisher.stop()

    # Every message was sent, in FIFO order.
    assert publisher.q.empty()
    assert mock_producer.send.call_count == 5
    sent_ids = [c.args[1]["id"] for c in mock_producer.send.call_args_list]
    assert sent_ids == expected_ids
    mock_producer.flush.assert_called_once()
    mock_producer.close.assert_called_once()
@pytest.mark.asyncio
async def test_publisher_rejects_messages_during_drain():
    """Verify Publisher rejects new messages during shutdown.

    The run loop is never started. The publisher is put into the draining
    state by hand, after which send() must raise RuntimeError.
    """
    client = MagicMock()
    client.create_producer.return_value = MagicMock()
    pub = Publisher(
        client=client,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=1.0,
    )

    # One message is already waiting in the queue when shutdown begins.
    await pub.q.put(("id-1", {"data": 1}))

    # Enter the shutdown state manually: stop running, start draining.
    pub.running, pub.draining = False, True

    with pytest.raises(RuntimeError, match="Publisher is shutting down"):
        await pub.send("id-2", {"data": 2})
@pytest.mark.asyncio
async def test_publisher_drain_timeout():
    """Verify Publisher respects drain timeout.

    Ten messages are queued, but each send blocks for ~0.1 s while the
    drain budget is only 0.2 s. Shutdown must therefore finish quickly,
    leave some messages unsent, and still flush and close the producer.
    """
    mock_client = MagicMock()
    mock_producer = MagicMock()
    mock_client.create_producer.return_value = mock_producer

    publisher = Publisher(
        client=mock_client,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=0.2,  # Short timeout for testing
    )

    total_messages = 10
    for i in range(total_messages):
        await publisher.q.put((f"id-{i}", {"data": i}))

    def slow_send(*args, **kwargs):
        # Blocking on purpose: models a slow synchronous producer.send().
        time.sleep(0.1)

    mock_producer.send.side_effect = slow_send

    async def mock_run_with_timeout():
        # Drain until the queue is empty or the drain budget is spent.
        # monotonic() is immune to wall-clock adjustments.
        producer = mock_producer
        deadline = time.monotonic() + publisher.drain_timeout
        while not publisher.q.empty() and time.monotonic() < deadline:
            try:
                msg_id, item = await asyncio.wait_for(
                    publisher.q.get(), timeout=0.05
                )
            except asyncio.TimeoutError:
                break
            producer.send(item, {"id": msg_id})
        producer.flush()
        producer.close()

    with patch.object(publisher, "run") as mock_run:
        mock_run.side_effect = mock_run_with_timeout
        started = time.monotonic()
        await publisher.start()
        await publisher.stop()
        elapsed = time.monotonic() - started

    # Should time out quickly ...
    assert elapsed < 1.0
    # ... because the timeout actually cut the drain short.
    assert mock_producer.send.call_count < total_messages
    # flush and close must still happen after a timed-out drain.
    mock_producer.flush.assert_called_once()
    mock_producer.close.assert_called_once()
@pytest.mark.asyncio
async def test_publisher_successful_drain():
    """Verify Publisher drains successfully under normal conditions.

    Every queued message must reach the producer exactly once, in queue
    order. Each send is called positionally as ``(content, {"id": msg_id})``.
    The producer must then be flushed and closed once.
    """
    mock_client = MagicMock()
    mock_producer = MagicMock()
    mock_client.create_producer.return_value = mock_producer

    publisher = Publisher(
        client=mock_client,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=2.0,
    )

    # Add messages directly to the queue, remembering what was queued.
    messages = [(f"id-{i}", {"data": i}) for i in range(3)]
    for entry in messages:
        await publisher.q.put(entry)

    async def mock_successful_drain():
        # Send everything that is queued, then flush and close.
        producer = mock_producer
        while not publisher.q.empty():
            msg_id, item = await publisher.q.get()
            producer.send(item, {"id": msg_id})
        producer.flush()
        producer.close()

    with patch.object(publisher, "run") as mock_run:
        mock_run.side_effect = mock_successful_drain
        await publisher.start()
        await publisher.stop()

    # All messages should be sent ...
    assert publisher.q.empty()
    assert mock_producer.send.call_count == len(messages)
    # ... with the queued content and id, in queue order.
    sent = [(c.args, c.kwargs) for c in mock_producer.send.call_args_list]
    expected = [((item, {"id": msg_id}), {}) for msg_id, item in messages]
    assert sent == expected
    mock_producer.flush.assert_called_once()
    mock_producer.close.assert_called_once()
@pytest.mark.asyncio
async def test_publisher_state_transitions():
    """Test Publisher state transitions during graceful shutdown.

    A freshly built publisher is running and not draining. The simulated
    run() enters the drain phase, sends the single queued message, then
    clears ``draining``. After stop() both flags must be False, the queue
    must be empty, and exactly that message must have been sent.
    """
    mock_client = MagicMock()
    mock_producer = MagicMock()
    mock_client.create_producer.return_value = mock_producer

    publisher = Publisher(
        client=mock_client,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=1.0,
    )

    # Initial state
    assert publisher.running is True
    assert publisher.draining is False

    # Add message directly
    await publisher.q.put(("id-1", {"data": 1}))

    async def mock_run_with_states():
        # Enter the drain phase.
        publisher.running = False
        publisher.draining = True
        # Send everything still queued.
        while not publisher.q.empty():
            msg_id, item = await publisher.q.get()
            mock_producer.send(item, {"id": msg_id})
        # Leave the drain phase, then release the producer.
        publisher.draining = False
        mock_producer.flush()
        mock_producer.close()

    with patch.object(publisher, "run") as mock_run:
        mock_run.side_effect = mock_run_with_states
        await publisher.start()
        await publisher.stop()

    # Should have completed all state transitions
    assert publisher.running is False
    assert publisher.draining is False
    assert publisher.q.empty()
    mock_producer.send.assert_called_once_with({"data": 1}, {"id": "id-1"})
    mock_producer.flush.assert_called_once()
    mock_producer.close.assert_called_once()
@pytest.mark.asyncio
async def test_publisher_exception_handling():
    """Test Publisher handles exceptions during drain gracefully.

    The second producer.send() raises. The simulated drain must record
    that failure, keep going until the queue is empty, and still flush
    and close the producer.
    """
    mock_client = MagicMock()
    mock_producer = MagicMock()
    mock_client.create_producer.return_value = mock_producer

    # producer.send raises on its second call only.
    call_count = 0

    def failing_send(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        if call_count == 2:
            raise Exception("Send failed")

    mock_producer.send.side_effect = failing_send

    publisher = Publisher(
        client=mock_client,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=1.0,
    )

    # Add messages directly
    await publisher.q.put(("id-1", {"data": 1}))
    await publisher.q.put(("id-2", {"data": 2}))

    # (msg_id, exception) for every send the simulated drain survived.
    failures = []

    async def mock_run_with_exceptions():
        producer = mock_producer
        while not publisher.q.empty():
            msg_id, item = await publisher.q.get()
            # Guard only the send: a failed send is recorded (deliberate
            # best-effort) and the drain moves on to the next message.
            try:
                producer.send(item, {"id": msg_id})
            except Exception as exc:
                failures.append((msg_id, exc))
        # Always flush and close, even after a failed send.
        producer.flush()
        producer.close()

    with patch.object(publisher, "run") as mock_run:
        mock_run.side_effect = mock_run_with_exceptions
        await publisher.start()
        await publisher.stop()

    # Should have attempted to send both messages ...
    assert mock_producer.send.call_count == 2
    # ... with exactly the second one failing, and the queue still drained.
    assert [msg_id for msg_id, _ in failures] == ["id-2"]
    assert str(failures[0][1]) == "Send failed"
    assert publisher.q.empty()
    mock_producer.flush.assert_called_once()
    mock_producer.close.assert_called_once()