mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
330 lines
No EOL
10 KiB
Python
330 lines
No EOL
10 KiB
Python
"""Unit tests for Publisher graceful shutdown functionality."""
|
|
|
|
import pytest
|
|
import asyncio
|
|
import time
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from trustgraph.base.publisher import Publisher
|
|
|
|
|
|
@pytest.fixture
def mock_pulsar_backend():
    """Build a mocked Pulsar backend whose ``create_producer`` returns a fake producer.

    The producer itself is an ``AsyncMock``, but its ``send``, ``flush`` and
    ``close`` attributes are replaced with plain ``MagicMock`` objects so that
    calling them returns immediately rather than yielding a coroutine.
    """
    fake_producer = AsyncMock()
    for method_name in ("send", "flush", "close"):
        setattr(fake_producer, method_name, MagicMock())

    fake_backend = MagicMock()
    fake_backend.create_producer.return_value = fake_producer
    return fake_backend
|
|
|
|
|
|
@pytest.fixture
def publisher(mock_pulsar_backend):
    """Return a ``Publisher`` built on the mocked backend fixture.

    The instance targets ``"test-topic"`` with a ``dict`` schema, a
    ``max_size`` of 10 and a ``drain_timeout`` of 2.0.
    """
    settings = {
        "backend": mock_pulsar_backend,
        "topic": "test-topic",
        "schema": dict,
        "max_size": 10,
        "drain_timeout": 2.0,
    }
    return Publisher(**settings)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_publisher_queue_drain():
    """Check that every queued message is sent, then flushed and closed, on stop.

    ``Publisher.run`` is replaced by a stand-in that empties the queue through
    the mock producer, so the real run loop is never executed here.
    """
    message_count = 5
    fake_producer = MagicMock()
    fake_backend = MagicMock()
    fake_backend.create_producer.return_value = fake_producer

    publisher = Publisher(
        backend=fake_backend,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=1.0,  # short timeout keeps the test quick
    )

    # Load the queue by hand instead of going through the run loop.
    for index in range(message_count):
        await publisher.q.put((f"id-{index}", {"data": index}))
    assert not publisher.q.empty()

    async def drain_queue():
        # Stand-in for run(): forward everything that is queued, then
        # flush and close the producer.
        while not publisher.q.empty():
            try:
                msg_id, payload = await asyncio.wait_for(
                    publisher.q.get(), timeout=0.1
                )
            except asyncio.TimeoutError:
                break
            fake_producer.send(payload, {"id": msg_id})
        fake_producer.flush()
        fake_producer.close()

    with patch.object(publisher, "run", side_effect=drain_queue):
        # Start and immediately stop the publisher.
        await publisher.start()
        await publisher.stop()

        # Nothing may be left behind and the producer must be shut down once.
        assert publisher.q.empty()
        assert fake_producer.send.call_count == message_count
        fake_producer.flush.assert_called_once()
        fake_producer.close.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_publisher_rejects_messages_during_drain():
    """Check that ``send`` raises ``RuntimeError`` once the publisher is draining."""
    expected_error = "Publisher is shutting down"

    fake_producer = MagicMock()
    fake_backend = MagicMock()
    fake_backend.create_producer.return_value = fake_producer

    publisher = Publisher(
        backend=fake_backend,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=1.0,
    )

    # The run loop is never started; a single message is queued by hand.
    await publisher.q.put(("id-1", {"data": 1}))

    # Flip the flags manually to put the publisher into its shutdown state.
    publisher.running = False
    publisher.draining = True

    # A new message offered while draining must be refused.
    with pytest.raises(RuntimeError, match=expected_error):
        await publisher.send("id-2", {"data": 2})
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_publisher_drain_timeout():
    """Verify that a drain bounded by ``drain_timeout`` ends early yet still cleans up.

    ``Publisher.run`` is patched with a stand-in that stops forwarding
    messages once ``publisher.drain_timeout`` seconds have passed. Each mocked
    send blocks for 0.1s, so the ten queued messages cannot all fit into the
    0.2s budget.

    Time is measured with ``time.monotonic()`` so that a wall-clock
    adjustment during the test cannot distort the deadline or the elapsed
    time.
    """
    total_messages = 10

    mock_backend = MagicMock()
    mock_producer = MagicMock()
    mock_backend.create_producer.return_value = mock_producer

    publisher = Publisher(
        backend=mock_backend,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=0.2,  # Short timeout for testing
    )

    # Fill queue with many messages directly
    for i in range(total_messages):
        await publisher.q.put((f"id-{i}", {"data": i}))

    def slow_send(*args, **kwargs):
        # Deliberately blocking: simulates a producer whose send is slow.
        time.sleep(0.1)

    mock_producer.send.side_effect = slow_send

    with patch.object(publisher, 'run') as mock_run:

        async def mock_run_with_timeout():
            """Drain until the queue is empty or the deadline passes."""
            producer = mock_producer
            deadline = time.monotonic() + publisher.drain_timeout

            while not publisher.q.empty() and time.monotonic() < deadline:
                try:
                    msg_id, item = await asyncio.wait_for(
                        publisher.q.get(), timeout=0.05
                    )
                    producer.send(item, {"id": msg_id})
                except asyncio.TimeoutError:
                    break

            # Flush and close happen regardless of how the loop ended.
            producer.flush()
            producer.close()

        mock_run.side_effect = mock_run_with_timeout

        started = time.monotonic()
        await publisher.start()
        await publisher.stop()
        elapsed = time.monotonic() - started

        # Should timeout quickly
        assert elapsed < 1.0

        # The deadline must have cut the drain short: sending all ten
        # messages at 0.1s each would need a full second.
        assert mock_producer.send.call_count < total_messages

        # Should have called flush and close even with timeout
        mock_producer.flush.assert_called_once()
        mock_producer.close.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_publisher_successful_drain():
    """Verify that a normal drain sends every queued message, in order.

    ``Publisher.run`` is patched with a stand-in that forwards each queued
    ``(id, item)`` pair as ``producer.send(item, {"id": id})``. The test then
    checks both positional arguments of every recorded call, so a wrong
    payload or a wrong id is caught.
    """
    mock_backend = MagicMock()
    mock_producer = MagicMock()
    mock_backend.create_producer.return_value = mock_producer

    publisher = Publisher(
        backend=mock_backend,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=2.0,
    )

    # Add messages directly to queue, keeping them for the final comparison.
    messages = [{"data": i} for i in range(3)]
    for i, msg in enumerate(messages):
        await publisher.q.put((f"id-{i}", msg))

    with patch.object(publisher, 'run') as mock_run:

        async def mock_successful_drain():
            """Forward everything in the queue, then flush and close."""
            producer = mock_producer

            while not publisher.q.empty():
                msg_id, item = await publisher.q.get()
                producer.send(item, {"id": msg_id})

            producer.flush()
            producer.close()

        mock_run.side_effect = mock_successful_drain

        await publisher.start()
        await publisher.stop()

        # All messages should be sent
        assert publisher.q.empty()
        assert mock_producer.send.call_count == len(messages)

        # Verify each call carried the expected payload and id, in queue
        # order. The stand-in passes both positionally, so kwargs is empty.
        sent_calls = mock_producer.send.call_args_list
        for i, (sent, expected) in enumerate(zip(sent_calls, messages)):
            args, kwargs = sent
            assert args == (expected, {"id": f"id-{i}"})
            assert kwargs == {}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_publisher_state_transitions():
    """Check the ``running``/``draining`` flags before and after a shutdown cycle.

    ``Publisher.run`` is replaced by a stand-in that sets the flags itself
    while it empties the queue, so the final assertions describe the state
    left behind after ``start()`` and ``stop()`` have both returned.
    """
    fake_producer = MagicMock()
    fake_backend = MagicMock()
    fake_backend.create_producer.return_value = fake_producer

    publisher = Publisher(
        backend=fake_backend,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=1.0,
    )

    # A freshly constructed publisher is running and not draining.
    assert publisher.running is True
    assert publisher.draining is False

    # Queue a single message by hand.
    await publisher.q.put(("id-1", {"data": 1}))

    async def run_through_states():
        # Enter the drain phase.
        publisher.running = False
        publisher.draining = True

        # Forward whatever is queued.
        while not publisher.q.empty():
            msg_id, payload = await publisher.q.get()
            fake_producer.send(payload, {"id": msg_id})

        # Leave the drain phase, then shut the producer down.
        publisher.draining = False
        fake_producer.flush()
        fake_producer.close()

    with patch.object(publisher, "run", side_effect=run_through_states):
        await publisher.start()
        await publisher.stop()

        # Both flags must be cleared and the producer used exactly once each.
        assert publisher.running is False
        assert publisher.draining is False
        fake_producer.send.assert_called_once()
        fake_producer.flush.assert_called_once()
        fake_producer.close.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_publisher_exception_handling():
    """Test that a failing send does not stop the drain or skip cleanup.

    The mocked ``producer.send`` raises on its second call. ``Publisher.run``
    is patched with a stand-in that swallows send failures and keeps going,
    then always flushes and closes. The test checks that both messages were
    attempted, that the queue ended up empty, and that flush and close each
    ran once.
    """
    mock_backend = MagicMock()
    mock_producer = MagicMock()
    mock_backend.create_producer.return_value = mock_producer

    # Mock producer.send to raise exception on second call
    call_count = 0

    def failing_send(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        if call_count == 2:
            raise Exception("Send failed")

    mock_producer.send.side_effect = failing_send

    publisher = Publisher(
        backend=mock_backend,
        topic="test-topic",
        schema=dict,
        max_size=10,
        drain_timeout=1.0,
    )

    # Add messages directly
    await publisher.q.put(("id-1", {"data": 1}))
    await publisher.q.put(("id-2", {"data": 2}))

    with patch.object(publisher, 'run') as mock_run:

        async def mock_run_with_exceptions():
            """Drain the queue, tolerating failures of individual sends."""
            producer = mock_producer

            while not publisher.q.empty():
                msg_id, item = await publisher.q.get()
                # Only the send is guarded: a queue error should surface,
                # while a failed send must not abort the rest of the drain.
                try:
                    producer.send(item, {"id": msg_id})
                except Exception:
                    continue

            # Always call flush and close
            producer.flush()
            producer.close()

        mock_run.side_effect = mock_run_with_exceptions

        await publisher.start()
        await publisher.stop()

        # Should have attempted to send both messages and left nothing queued
        assert mock_producer.send.call_count == 2
        assert publisher.q.empty()
        mock_producer.flush.assert_called_once()
        mock_producer.close.assert_called_once()