mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-26 15:55:16 +02:00
Adding tests
This commit is contained in:
parent
5f31b9efd8
commit
34ac5279bb
4 changed files with 1423 additions and 0 deletions
454
tests/integration/test_import_export_graceful_shutdown.py
Normal file
454
tests/integration/test_import_export_graceful_shutdown.py
Normal file
|
|
@ -0,0 +1,454 @@
|
||||||
|
"""Integration tests for import/export graceful shutdown functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from aiohttp import web, WSMsgType, ClientWebSocketResponse
|
||||||
|
from trustgraph.gateway.dispatch.triples_import import TriplesImport
|
||||||
|
from trustgraph.gateway.dispatch.triples_export import TriplesExport
|
||||||
|
from trustgraph.gateway.running import Running
|
||||||
|
from trustgraph.base.publisher import Publisher
|
||||||
|
from trustgraph.base.subscriber import Subscriber
|
||||||
|
|
||||||
|
|
||||||
|
class MockPulsarMessage:
|
||||||
|
"""Mock Pulsar message for testing."""
|
||||||
|
|
||||||
|
def __init__(self, data, message_id="test-id"):
|
||||||
|
self._data = data
|
||||||
|
self._message_id = message_id
|
||||||
|
self._properties = {"id": message_id}
|
||||||
|
|
||||||
|
def value(self):
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
def properties(self):
|
||||||
|
return self._properties
|
||||||
|
|
||||||
|
|
||||||
|
class MockWebSocket:
|
||||||
|
"""Mock WebSocket for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.messages = []
|
||||||
|
self.closed = False
|
||||||
|
self._close_called = False
|
||||||
|
|
||||||
|
async def send_json(self, data):
|
||||||
|
if self.closed:
|
||||||
|
raise Exception("WebSocket is closed")
|
||||||
|
self.messages.append(data)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
self._close_called = True
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
"""Mock message json() method."""
|
||||||
|
return {
|
||||||
|
"metadata": {
|
||||||
|
"id": "test-id",
|
||||||
|
"metadata": {},
|
||||||
|
"user": "test-user",
|
||||||
|
"collection": "test-collection"
|
||||||
|
},
|
||||||
|
"triples": [["subject", "predicate", "object"]]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_pulsar_client():
|
||||||
|
"""Mock Pulsar client for integration testing."""
|
||||||
|
client = MagicMock()
|
||||||
|
|
||||||
|
# Mock producer
|
||||||
|
producer = MagicMock()
|
||||||
|
producer.send = MagicMock()
|
||||||
|
producer.flush = MagicMock()
|
||||||
|
producer.close = MagicMock()
|
||||||
|
client.create_producer.return_value = producer
|
||||||
|
|
||||||
|
# Mock consumer
|
||||||
|
consumer = MagicMock()
|
||||||
|
consumer.receive = AsyncMock()
|
||||||
|
consumer.acknowledge = MagicMock()
|
||||||
|
consumer.negative_acknowledge = MagicMock()
|
||||||
|
consumer.pause_message_listener = MagicMock()
|
||||||
|
consumer.unsubscribe = MagicMock()
|
||||||
|
consumer.close = MagicMock()
|
||||||
|
client.subscribe.return_value = consumer
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_import_graceful_shutdown_integration():
|
||||||
|
"""Test import path handles shutdown gracefully with real message flow."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_producer = MagicMock()
|
||||||
|
mock_client.create_producer.return_value = mock_producer
|
||||||
|
|
||||||
|
# Track sent messages
|
||||||
|
sent_messages = []
|
||||||
|
def track_send(message, properties=None):
|
||||||
|
sent_messages.append((message, properties))
|
||||||
|
|
||||||
|
mock_producer.send.side_effect = track_send
|
||||||
|
|
||||||
|
ws = MockWebSocket()
|
||||||
|
running = Running()
|
||||||
|
|
||||||
|
# Create import handler
|
||||||
|
import_handler = TriplesImport(
|
||||||
|
ws=ws,
|
||||||
|
running=running,
|
||||||
|
pulsar_client=mock_client,
|
||||||
|
queue="test-triples-import"
|
||||||
|
)
|
||||||
|
|
||||||
|
await import_handler.start()
|
||||||
|
|
||||||
|
# Send multiple messages rapidly
|
||||||
|
messages = []
|
||||||
|
for i in range(10):
|
||||||
|
msg_data = {
|
||||||
|
"metadata": {
|
||||||
|
"id": f"msg-{i}",
|
||||||
|
"metadata": {},
|
||||||
|
"user": "test-user",
|
||||||
|
"collection": "test-collection"
|
||||||
|
},
|
||||||
|
"triples": [[f"subject-{i}", "predicate", f"object-{i}"]]
|
||||||
|
}
|
||||||
|
messages.append(msg_data)
|
||||||
|
|
||||||
|
# Create mock message with json() method
|
||||||
|
mock_msg = MagicMock()
|
||||||
|
mock_msg.json.return_value = msg_data
|
||||||
|
|
||||||
|
await import_handler.receive(mock_msg)
|
||||||
|
|
||||||
|
# Allow brief processing time
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# Shutdown while messages may be in flight
|
||||||
|
await import_handler.destroy()
|
||||||
|
|
||||||
|
# Verify all messages reached producer
|
||||||
|
assert len(sent_messages) == 10
|
||||||
|
|
||||||
|
# Verify proper shutdown order was followed
|
||||||
|
mock_producer.flush.assert_called_once()
|
||||||
|
mock_producer.close.assert_called_once()
|
||||||
|
|
||||||
|
# Verify messages have correct content
|
||||||
|
for i, (message, properties) in enumerate(sent_messages):
|
||||||
|
assert message.metadata.id == f"msg-{i}"
|
||||||
|
assert len(message.triples) == 1
|
||||||
|
assert message.triples[0][0] == f"subject-{i}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_export_no_message_loss_integration():
|
||||||
|
"""Test export path doesn't lose acknowledged messages."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_consumer = MagicMock()
|
||||||
|
mock_client.subscribe.return_value = mock_consumer
|
||||||
|
|
||||||
|
# Create test messages
|
||||||
|
test_messages = []
|
||||||
|
for i in range(20):
|
||||||
|
msg_data = {
|
||||||
|
"metadata": {
|
||||||
|
"id": f"export-msg-{i}",
|
||||||
|
"metadata": {},
|
||||||
|
"user": "test-user",
|
||||||
|
"collection": "test-collection"
|
||||||
|
},
|
||||||
|
"triples": [[f"export-subject-{i}", "predicate", f"export-object-{i}"]]
|
||||||
|
}
|
||||||
|
test_messages.append(MockPulsarMessage(msg_data, f"export-msg-{i}"))
|
||||||
|
|
||||||
|
# Mock consumer to provide messages
|
||||||
|
message_iter = iter(test_messages)
|
||||||
|
async def mock_receive():
|
||||||
|
try:
|
||||||
|
return next(message_iter)
|
||||||
|
except StopIteration:
|
||||||
|
# Simulate no more messages
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
mock_consumer.receive = mock_receive
|
||||||
|
|
||||||
|
ws = MockWebSocket()
|
||||||
|
running = Running()
|
||||||
|
|
||||||
|
# Create export handler
|
||||||
|
export_handler = TriplesExport(
|
||||||
|
ws=ws,
|
||||||
|
running=running,
|
||||||
|
pulsar_client=mock_client,
|
||||||
|
queue="test-triples-export",
|
||||||
|
consumer="test-consumer",
|
||||||
|
subscriber="test-subscriber"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start export in background
|
||||||
|
export_task = asyncio.create_task(export_handler.run())
|
||||||
|
|
||||||
|
# Allow some messages to be processed
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# Verify some messages were sent to websocket
|
||||||
|
initial_count = len(ws.messages)
|
||||||
|
assert initial_count > 0
|
||||||
|
|
||||||
|
# Force shutdown
|
||||||
|
await export_handler.destroy()
|
||||||
|
|
||||||
|
# Wait for export task to complete
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(export_task, timeout=2.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
export_task.cancel()
|
||||||
|
|
||||||
|
# Verify websocket was closed
|
||||||
|
assert ws._close_called is True
|
||||||
|
|
||||||
|
# Verify messages that were acknowledged were actually sent
|
||||||
|
final_count = len(ws.messages)
|
||||||
|
assert final_count >= initial_count
|
||||||
|
|
||||||
|
# Verify no partial/corrupted messages
|
||||||
|
for msg in ws.messages:
|
||||||
|
assert "metadata" in msg
|
||||||
|
assert "triples" in msg
|
||||||
|
assert msg["metadata"]["id"].startswith("export-msg-")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_import_export_shutdown():
|
||||||
|
"""Test concurrent import and export shutdown scenarios."""
|
||||||
|
# Setup mock clients
|
||||||
|
import_client = MagicMock()
|
||||||
|
export_client = MagicMock()
|
||||||
|
|
||||||
|
import_producer = MagicMock()
|
||||||
|
export_consumer = MagicMock()
|
||||||
|
|
||||||
|
import_client.create_producer.return_value = import_producer
|
||||||
|
export_client.subscribe.return_value = export_consumer
|
||||||
|
|
||||||
|
# Track operations
|
||||||
|
import_operations = []
|
||||||
|
export_operations = []
|
||||||
|
|
||||||
|
def track_import_send(message, properties=None):
|
||||||
|
import_operations.append(("send", message.metadata.id))
|
||||||
|
|
||||||
|
def track_import_flush():
|
||||||
|
import_operations.append(("flush",))
|
||||||
|
|
||||||
|
def track_export_ack(msg):
|
||||||
|
export_operations.append(("ack", msg.properties()["id"]))
|
||||||
|
|
||||||
|
import_producer.send.side_effect = track_import_send
|
||||||
|
import_producer.flush.side_effect = track_import_flush
|
||||||
|
export_consumer.acknowledge.side_effect = track_export_ack
|
||||||
|
|
||||||
|
# Create handlers
|
||||||
|
import_ws = MockWebSocket()
|
||||||
|
export_ws = MockWebSocket()
|
||||||
|
import_running = Running()
|
||||||
|
export_running = Running()
|
||||||
|
|
||||||
|
import_handler = TriplesImport(
|
||||||
|
ws=import_ws,
|
||||||
|
running=import_running,
|
||||||
|
pulsar_client=import_client,
|
||||||
|
queue="concurrent-import"
|
||||||
|
)
|
||||||
|
|
||||||
|
export_handler = TriplesExport(
|
||||||
|
ws=export_ws,
|
||||||
|
running=export_running,
|
||||||
|
pulsar_client=export_client,
|
||||||
|
queue="concurrent-export",
|
||||||
|
consumer="concurrent-consumer",
|
||||||
|
subscriber="concurrent-subscriber"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start both handlers
|
||||||
|
await import_handler.start()
|
||||||
|
|
||||||
|
# Send messages to import
|
||||||
|
for i in range(5):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.json.return_value = {
|
||||||
|
"metadata": {
|
||||||
|
"id": f"concurrent-{i}",
|
||||||
|
"metadata": {},
|
||||||
|
"user": "test-user",
|
||||||
|
"collection": "test-collection"
|
||||||
|
},
|
||||||
|
"triples": [[f"concurrent-subject-{i}", "predicate", "object"]]
|
||||||
|
}
|
||||||
|
await import_handler.receive(msg)
|
||||||
|
|
||||||
|
# Shutdown both concurrently
|
||||||
|
import_shutdown = asyncio.create_task(import_handler.destroy())
|
||||||
|
export_shutdown = asyncio.create_task(export_handler.destroy())
|
||||||
|
|
||||||
|
await asyncio.gather(import_shutdown, export_shutdown)
|
||||||
|
|
||||||
|
# Verify import operations completed properly
|
||||||
|
assert len(import_operations) == 6 # 5 sends + 1 flush
|
||||||
|
assert ("flush",) in import_operations
|
||||||
|
|
||||||
|
# Verify all import messages were processed
|
||||||
|
send_ops = [op for op in import_operations if op[0] == "send"]
|
||||||
|
assert len(send_ops) == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_websocket_close_during_message_processing():
|
||||||
|
"""Test graceful handling when websocket closes during active message processing."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_producer = MagicMock()
|
||||||
|
mock_client.create_producer.return_value = mock_producer
|
||||||
|
|
||||||
|
# Simulate slow message processing
|
||||||
|
processed_messages = []
|
||||||
|
async def slow_send(message, properties=None):
|
||||||
|
processed_messages.append(message.metadata.id)
|
||||||
|
await asyncio.sleep(0.1) # Simulate processing delay
|
||||||
|
|
||||||
|
mock_producer.send.side_effect = slow_send
|
||||||
|
|
||||||
|
ws = MockWebSocket()
|
||||||
|
running = Running()
|
||||||
|
|
||||||
|
import_handler = TriplesImport(
|
||||||
|
ws=ws,
|
||||||
|
running=running,
|
||||||
|
pulsar_client=mock_client,
|
||||||
|
queue="slow-processing-import"
|
||||||
|
)
|
||||||
|
|
||||||
|
await import_handler.start()
|
||||||
|
|
||||||
|
# Send many messages rapidly
|
||||||
|
message_tasks = []
|
||||||
|
for i in range(10):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.json.return_value = {
|
||||||
|
"metadata": {
|
||||||
|
"id": f"slow-msg-{i}",
|
||||||
|
"metadata": {},
|
||||||
|
"user": "test-user",
|
||||||
|
"collection": "test-collection"
|
||||||
|
},
|
||||||
|
"triples": [[f"slow-subject-{i}", "predicate", "object"]]
|
||||||
|
}
|
||||||
|
task = asyncio.create_task(import_handler.receive(msg))
|
||||||
|
message_tasks.append(task)
|
||||||
|
|
||||||
|
# Allow some processing to start
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
# Close websocket while messages are being processed
|
||||||
|
ws.closed = True
|
||||||
|
|
||||||
|
# Shutdown handler
|
||||||
|
await import_handler.destroy()
|
||||||
|
|
||||||
|
# Wait for all message tasks to complete
|
||||||
|
await asyncio.gather(*message_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Verify that messages that were being processed completed
|
||||||
|
# (graceful shutdown should allow in-flight processing to finish)
|
||||||
|
assert len(processed_messages) > 0
|
||||||
|
|
||||||
|
# Verify producer was properly flushed and closed
|
||||||
|
mock_producer.flush.assert_called_once()
|
||||||
|
mock_producer.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_backpressure_during_shutdown():
|
||||||
|
"""Test graceful shutdown under backpressure conditions."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_consumer = MagicMock()
|
||||||
|
mock_client.subscribe.return_value = mock_consumer
|
||||||
|
|
||||||
|
# Create messages that will cause backpressure
|
||||||
|
large_messages = []
|
||||||
|
for i in range(50):
|
||||||
|
msg_data = {
|
||||||
|
"metadata": {
|
||||||
|
"id": f"large-msg-{i}",
|
||||||
|
"metadata": {"large_field": "x" * 1000}, # Large metadata
|
||||||
|
"user": "test-user",
|
||||||
|
"collection": "test-collection"
|
||||||
|
},
|
||||||
|
"triples": [[f"large-subject-{i}", "predicate", f"large-object-{i}"]]
|
||||||
|
}
|
||||||
|
large_messages.append(MockPulsarMessage(msg_data, f"large-msg-{i}"))
|
||||||
|
|
||||||
|
# Mock slow websocket
|
||||||
|
class SlowWebSocket(MockWebSocket):
|
||||||
|
async def send_json(self, data):
|
||||||
|
await asyncio.sleep(0.02) # Slow send
|
||||||
|
await super().send_json(data)
|
||||||
|
|
||||||
|
ws = SlowWebSocket()
|
||||||
|
running = Running()
|
||||||
|
|
||||||
|
export_handler = TriplesExport(
|
||||||
|
ws=ws,
|
||||||
|
running=running,
|
||||||
|
pulsar_client=mock_client,
|
||||||
|
queue="backpressure-export",
|
||||||
|
consumer="backpressure-consumer",
|
||||||
|
subscriber="backpressure-subscriber"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock consumer with backpressure
|
||||||
|
message_queue = asyncio.Queue(maxsize=5) # Small queue
|
||||||
|
for msg in large_messages[:10]: # Only add first 10
|
||||||
|
await message_queue.put(msg)
|
||||||
|
|
||||||
|
async def mock_receive_with_backpressure():
|
||||||
|
return await message_queue.get()
|
||||||
|
|
||||||
|
mock_consumer.receive = mock_receive_with_backpressure
|
||||||
|
|
||||||
|
# Start export task
|
||||||
|
export_task = asyncio.create_task(export_handler.run())
|
||||||
|
|
||||||
|
# Allow some processing
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
# Shutdown under backpressure
|
||||||
|
shutdown_start = time.time()
|
||||||
|
await export_handler.destroy()
|
||||||
|
shutdown_duration = time.time() - shutdown_start
|
||||||
|
|
||||||
|
# Cancel export task
|
||||||
|
export_task.cancel()
|
||||||
|
try:
|
||||||
|
await export_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Verify graceful shutdown completed within reasonable time
|
||||||
|
assert shutdown_duration < 10.0 # Should not hang indefinitely
|
||||||
|
|
||||||
|
# Verify some messages were processed before shutdown
|
||||||
|
assert len(ws.messages) > 0
|
||||||
|
|
||||||
|
# Verify websocket was closed
|
||||||
|
assert ws._close_called is True
|
||||||
330
tests/unit/test_base/test_publisher_graceful_shutdown.py
Normal file
330
tests/unit/test_base/test_publisher_graceful_shutdown.py
Normal file
|
|
@ -0,0 +1,330 @@
|
||||||
|
"""Unit tests for Publisher graceful shutdown functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from trustgraph.base.publisher import Publisher
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_pulsar_client():
|
||||||
|
"""Mock Pulsar client for testing."""
|
||||||
|
client = MagicMock()
|
||||||
|
producer = AsyncMock()
|
||||||
|
producer.send = MagicMock()
|
||||||
|
producer.flush = MagicMock()
|
||||||
|
producer.close = MagicMock()
|
||||||
|
client.create_producer.return_value = producer
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def publisher(mock_pulsar_client):
|
||||||
|
"""Create Publisher instance for testing."""
|
||||||
|
return Publisher(
|
||||||
|
client=mock_pulsar_client,
|
||||||
|
topic="test-topic",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
drain_timeout=2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publisher_queue_drain():
|
||||||
|
"""Verify Publisher drains queue on shutdown."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_producer = MagicMock()
|
||||||
|
mock_client.create_producer.return_value = mock_producer
|
||||||
|
|
||||||
|
publisher = Publisher(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
drain_timeout=1.0 # Shorter timeout for testing
|
||||||
|
)
|
||||||
|
|
||||||
|
# Don't start the actual run loop - just test the drain logic
|
||||||
|
# Fill queue with messages directly
|
||||||
|
for i in range(5):
|
||||||
|
await publisher.q.put((f"id-{i}", {"data": i}))
|
||||||
|
|
||||||
|
# Verify queue has messages
|
||||||
|
assert not publisher.q.empty()
|
||||||
|
|
||||||
|
# Mock the producer creation in run() method by patching
|
||||||
|
with patch.object(publisher, 'run') as mock_run:
|
||||||
|
# Create a realistic run implementation that processes the queue
|
||||||
|
async def mock_run_impl():
|
||||||
|
# Simulate the actual run logic for drain
|
||||||
|
producer = mock_producer
|
||||||
|
while not publisher.q.empty():
|
||||||
|
try:
|
||||||
|
id, item = await asyncio.wait_for(publisher.q.get(), timeout=0.1)
|
||||||
|
producer.send(item, {"id": id})
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
producer.flush()
|
||||||
|
producer.close()
|
||||||
|
|
||||||
|
mock_run.side_effect = mock_run_impl
|
||||||
|
|
||||||
|
# Start and stop publisher
|
||||||
|
await publisher.start()
|
||||||
|
await publisher.stop()
|
||||||
|
|
||||||
|
# Verify all messages were sent
|
||||||
|
assert publisher.q.empty()
|
||||||
|
assert mock_producer.send.call_count == 5
|
||||||
|
mock_producer.flush.assert_called_once()
|
||||||
|
mock_producer.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publisher_rejects_messages_during_drain():
|
||||||
|
"""Verify Publisher rejects new messages during shutdown."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_producer = MagicMock()
|
||||||
|
mock_client.create_producer.return_value = mock_producer
|
||||||
|
|
||||||
|
publisher = Publisher(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
drain_timeout=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Don't start the actual run loop
|
||||||
|
# Add one message directly
|
||||||
|
await publisher.q.put(("id-1", {"data": 1}))
|
||||||
|
|
||||||
|
# Start shutdown process manually
|
||||||
|
publisher.running = False
|
||||||
|
publisher.draining = True
|
||||||
|
|
||||||
|
# Try to send message during drain
|
||||||
|
with pytest.raises(RuntimeError, match="Publisher is shutting down"):
|
||||||
|
await publisher.send("id-2", {"data": 2})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publisher_drain_timeout():
|
||||||
|
"""Verify Publisher respects drain timeout."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_producer = MagicMock()
|
||||||
|
mock_client.create_producer.return_value = mock_producer
|
||||||
|
|
||||||
|
publisher = Publisher(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
drain_timeout=0.2 # Short timeout for testing
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fill queue with many messages directly
|
||||||
|
for i in range(10):
|
||||||
|
await publisher.q.put((f"id-{i}", {"data": i}))
|
||||||
|
|
||||||
|
# Mock slow message processing
|
||||||
|
def slow_send(*args, **kwargs):
|
||||||
|
time.sleep(0.1) # Simulate slow send
|
||||||
|
|
||||||
|
mock_producer.send.side_effect = slow_send
|
||||||
|
|
||||||
|
with patch.object(publisher, 'run') as mock_run:
|
||||||
|
# Create a run implementation that respects timeout
|
||||||
|
async def mock_run_with_timeout():
|
||||||
|
producer = mock_producer
|
||||||
|
end_time = time.time() + publisher.drain_timeout
|
||||||
|
|
||||||
|
while not publisher.q.empty() and time.time() < end_time:
|
||||||
|
try:
|
||||||
|
id, item = await asyncio.wait_for(publisher.q.get(), timeout=0.05)
|
||||||
|
producer.send(item, {"id": id})
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
|
||||||
|
producer.flush()
|
||||||
|
producer.close()
|
||||||
|
|
||||||
|
mock_run.side_effect = mock_run_with_timeout
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
await publisher.start()
|
||||||
|
await publisher.stop()
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
# Should timeout quickly
|
||||||
|
assert end_time - start_time < 1.0
|
||||||
|
|
||||||
|
# Should have called flush and close even with timeout
|
||||||
|
mock_producer.flush.assert_called_once()
|
||||||
|
mock_producer.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publisher_successful_drain():
|
||||||
|
"""Verify Publisher drains successfully under normal conditions."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_producer = MagicMock()
|
||||||
|
mock_client.create_producer.return_value = mock_producer
|
||||||
|
|
||||||
|
publisher = Publisher(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
drain_timeout=2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add messages directly to queue
|
||||||
|
messages = []
|
||||||
|
for i in range(3):
|
||||||
|
msg = {"data": i}
|
||||||
|
await publisher.q.put((f"id-{i}", msg))
|
||||||
|
messages.append(msg)
|
||||||
|
|
||||||
|
with patch.object(publisher, 'run') as mock_run:
|
||||||
|
# Create a successful drain implementation
|
||||||
|
async def mock_successful_drain():
|
||||||
|
producer = mock_producer
|
||||||
|
processed = []
|
||||||
|
|
||||||
|
while not publisher.q.empty():
|
||||||
|
id, item = await publisher.q.get()
|
||||||
|
producer.send(item, {"id": id})
|
||||||
|
processed.append((id, item))
|
||||||
|
|
||||||
|
producer.flush()
|
||||||
|
producer.close()
|
||||||
|
return processed
|
||||||
|
|
||||||
|
mock_run.side_effect = mock_successful_drain
|
||||||
|
|
||||||
|
await publisher.start()
|
||||||
|
await publisher.stop()
|
||||||
|
|
||||||
|
# All messages should be sent
|
||||||
|
assert publisher.q.empty()
|
||||||
|
assert mock_producer.send.call_count == 3
|
||||||
|
|
||||||
|
# Verify correct messages were sent
|
||||||
|
sent_calls = mock_producer.send.call_args_list
|
||||||
|
for i, call in enumerate(sent_calls):
|
||||||
|
args, kwargs = call
|
||||||
|
assert args[0] == {"data": i} # message content
|
||||||
|
# Note: kwargs format depends on how send was called in mock
|
||||||
|
# Just verify message was sent with correct content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publisher_state_transitions():
|
||||||
|
"""Test Publisher state transitions during graceful shutdown."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_producer = MagicMock()
|
||||||
|
mock_client.create_producer.return_value = mock_producer
|
||||||
|
|
||||||
|
publisher = Publisher(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
drain_timeout=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initial state
|
||||||
|
assert publisher.running is True
|
||||||
|
assert publisher.draining is False
|
||||||
|
|
||||||
|
# Add message directly
|
||||||
|
await publisher.q.put(("id-1", {"data": 1}))
|
||||||
|
|
||||||
|
with patch.object(publisher, 'run') as mock_run:
|
||||||
|
# Mock run that simulates state transitions
|
||||||
|
async def mock_run_with_states():
|
||||||
|
# Simulate drain process
|
||||||
|
publisher.running = False
|
||||||
|
publisher.draining = True
|
||||||
|
|
||||||
|
# Process messages
|
||||||
|
while not publisher.q.empty():
|
||||||
|
id, item = await publisher.q.get()
|
||||||
|
mock_producer.send(item, {"id": id})
|
||||||
|
|
||||||
|
# Complete drain
|
||||||
|
publisher.draining = False
|
||||||
|
mock_producer.flush()
|
||||||
|
mock_producer.close()
|
||||||
|
|
||||||
|
mock_run.side_effect = mock_run_with_states
|
||||||
|
|
||||||
|
await publisher.start()
|
||||||
|
await publisher.stop()
|
||||||
|
|
||||||
|
# Should have completed all state transitions
|
||||||
|
assert publisher.running is False
|
||||||
|
assert publisher.draining is False
|
||||||
|
mock_producer.send.assert_called_once()
|
||||||
|
mock_producer.flush.assert_called_once()
|
||||||
|
mock_producer.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publisher_exception_handling():
|
||||||
|
"""Test Publisher handles exceptions during drain gracefully."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_producer = MagicMock()
|
||||||
|
mock_client.create_producer.return_value = mock_producer
|
||||||
|
|
||||||
|
# Mock producer.send to raise exception on second call
|
||||||
|
call_count = 0
|
||||||
|
def failing_send(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 2:
|
||||||
|
raise Exception("Send failed")
|
||||||
|
|
||||||
|
mock_producer.send.side_effect = failing_send
|
||||||
|
|
||||||
|
publisher = Publisher(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
drain_timeout=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add messages directly
|
||||||
|
await publisher.q.put(("id-1", {"data": 1}))
|
||||||
|
await publisher.q.put(("id-2", {"data": 2}))
|
||||||
|
|
||||||
|
with patch.object(publisher, 'run') as mock_run:
|
||||||
|
# Mock run that handles exceptions gracefully
|
||||||
|
async def mock_run_with_exceptions():
|
||||||
|
producer = mock_producer
|
||||||
|
|
||||||
|
while not publisher.q.empty():
|
||||||
|
try:
|
||||||
|
id, item = await publisher.q.get()
|
||||||
|
producer.send(item, {"id": id})
|
||||||
|
except Exception as e:
|
||||||
|
# Log exception but continue processing
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Always call flush and close
|
||||||
|
producer.flush()
|
||||||
|
producer.close()
|
||||||
|
|
||||||
|
mock_run.side_effect = mock_run_with_exceptions
|
||||||
|
|
||||||
|
await publisher.start()
|
||||||
|
await publisher.stop()
|
||||||
|
|
||||||
|
# Should have attempted to send both messages
|
||||||
|
assert mock_producer.send.call_count == 2
|
||||||
|
mock_producer.flush.assert_called_once()
|
||||||
|
mock_producer.close.assert_called_once()
|
||||||
315
tests/unit/test_base/test_subscriber_graceful_shutdown.py
Normal file
315
tests/unit/test_base/test_subscriber_graceful_shutdown.py
Normal file
|
|
@ -0,0 +1,315 @@
|
||||||
|
"""Unit tests for Subscriber graceful shutdown functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from trustgraph.base.subscriber import Subscriber
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_pulsar_client():
|
||||||
|
"""Mock Pulsar client for testing."""
|
||||||
|
client = MagicMock()
|
||||||
|
consumer = MagicMock()
|
||||||
|
consumer.receive = MagicMock()
|
||||||
|
consumer.acknowledge = MagicMock()
|
||||||
|
consumer.negative_acknowledge = MagicMock()
|
||||||
|
consumer.pause_message_listener = MagicMock()
|
||||||
|
consumer.unsubscribe = MagicMock()
|
||||||
|
consumer.close = MagicMock()
|
||||||
|
client.subscribe.return_value = consumer
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def subscriber(mock_pulsar_client):
|
||||||
|
"""Create Subscriber instance for testing."""
|
||||||
|
return Subscriber(
|
||||||
|
client=mock_pulsar_client,
|
||||||
|
topic="test-topic",
|
||||||
|
subscription="test-subscription",
|
||||||
|
consumer_name="test-consumer",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
drain_timeout=2.0,
|
||||||
|
backpressure_strategy="block"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_message(message_id="test-id", data=None):
|
||||||
|
"""Create a mock Pulsar message."""
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.properties.return_value = {"id": message_id}
|
||||||
|
msg.value.return_value = data or {"test": "data"}
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscriber_deferred_acknowledgment_success():
|
||||||
|
"""Verify Subscriber only acks on successful delivery."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_consumer = MagicMock()
|
||||||
|
mock_client.subscribe.return_value = mock_consumer
|
||||||
|
|
||||||
|
subscriber = Subscriber(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
subscription="test-subscription",
|
||||||
|
consumer_name="test-consumer",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
backpressure_strategy="block"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create queue for subscription
|
||||||
|
queue = await subscriber.subscribe("test-queue")
|
||||||
|
|
||||||
|
# Create mock message
|
||||||
|
msg = create_mock_message("msg-1", {"data": "test"})
|
||||||
|
|
||||||
|
# Process message
|
||||||
|
await subscriber._process_message(msg)
|
||||||
|
|
||||||
|
# Should acknowledge successful delivery
|
||||||
|
mock_consumer.acknowledge.assert_called_once_with(msg)
|
||||||
|
mock_consumer.negative_acknowledge.assert_not_called()
|
||||||
|
|
||||||
|
# Message should be in queue
|
||||||
|
assert not queue.empty()
|
||||||
|
received_msg = await queue.get()
|
||||||
|
assert received_msg == {"data": "test"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscriber_deferred_acknowledgment_failure():
|
||||||
|
"""Verify Subscriber negative acks on delivery failure."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_consumer = MagicMock()
|
||||||
|
mock_client.subscribe.return_value = mock_consumer
|
||||||
|
|
||||||
|
subscriber = Subscriber(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
subscription="test-subscription",
|
||||||
|
consumer_name="test-consumer",
|
||||||
|
schema=dict,
|
||||||
|
max_size=1, # Very small queue
|
||||||
|
backpressure_strategy="drop_new"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create queue and fill it
|
||||||
|
queue = await subscriber.subscribe("test-queue")
|
||||||
|
await queue.put({"existing": "data"})
|
||||||
|
|
||||||
|
# Create mock message - should be dropped
|
||||||
|
msg = create_mock_message("msg-1", {"data": "test"})
|
||||||
|
|
||||||
|
# Process message (should fail due to full queue + drop_new strategy)
|
||||||
|
await subscriber._process_message(msg)
|
||||||
|
|
||||||
|
# Should negative acknowledge failed delivery
|
||||||
|
mock_consumer.negative_acknowledge.assert_called_once_with(msg)
|
||||||
|
mock_consumer.acknowledge.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscriber_backpressure_strategies():
|
||||||
|
"""Test different backpressure strategies."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_consumer = MagicMock()
|
||||||
|
mock_client.subscribe.return_value = mock_consumer
|
||||||
|
|
||||||
|
# Test drop_oldest strategy
|
||||||
|
subscriber = Subscriber(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
subscription="test-subscription",
|
||||||
|
consumer_name="test-consumer",
|
||||||
|
schema=dict,
|
||||||
|
max_size=2,
|
||||||
|
backpressure_strategy="drop_oldest"
|
||||||
|
)
|
||||||
|
|
||||||
|
queue = await subscriber.subscribe("test-queue")
|
||||||
|
|
||||||
|
# Fill queue
|
||||||
|
await queue.put({"data": "old1"})
|
||||||
|
await queue.put({"data": "old2"})
|
||||||
|
|
||||||
|
# Add new message (should drop oldest)
|
||||||
|
msg = create_mock_message("msg-1", {"data": "new"})
|
||||||
|
await subscriber._process_message(msg)
|
||||||
|
|
||||||
|
# Should acknowledge delivery
|
||||||
|
mock_consumer.acknowledge.assert_called_once_with(msg)
|
||||||
|
|
||||||
|
# Queue should have new message (old one dropped)
|
||||||
|
messages = []
|
||||||
|
while not queue.empty():
|
||||||
|
messages.append(await queue.get())
|
||||||
|
|
||||||
|
# Should contain old2 and new (old1 was dropped)
|
||||||
|
assert len(messages) == 2
|
||||||
|
assert {"data": "new"} in messages
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscriber_graceful_shutdown():
|
||||||
|
"""Test Subscriber graceful shutdown with queue draining."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_consumer = MagicMock()
|
||||||
|
mock_client.subscribe.return_value = mock_consumer
|
||||||
|
|
||||||
|
subscriber = Subscriber(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
subscription="test-subscription",
|
||||||
|
consumer_name="test-consumer",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
drain_timeout=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
await subscriber.start()
|
||||||
|
|
||||||
|
# Create subscription with messages
|
||||||
|
queue = await subscriber.subscribe("test-queue")
|
||||||
|
await queue.put({"data": "msg1"})
|
||||||
|
await queue.put({"data": "msg2"})
|
||||||
|
|
||||||
|
# Initial state
|
||||||
|
assert subscriber.running is True
|
||||||
|
assert subscriber.draining is False
|
||||||
|
|
||||||
|
# Start shutdown
|
||||||
|
stop_task = asyncio.create_task(subscriber.stop())
|
||||||
|
|
||||||
|
# Allow brief processing
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# Should be in drain state
|
||||||
|
assert subscriber.running is False
|
||||||
|
assert subscriber.draining is True
|
||||||
|
|
||||||
|
# Should pause message listener
|
||||||
|
mock_consumer.pause_message_listener.assert_called_once()
|
||||||
|
|
||||||
|
# Complete shutdown
|
||||||
|
await stop_task
|
||||||
|
|
||||||
|
# Should have cleaned up
|
||||||
|
mock_consumer.unsubscribe.assert_called_once()
|
||||||
|
mock_consumer.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscriber_drain_timeout():
|
||||||
|
"""Test Subscriber respects drain timeout."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_consumer = MagicMock()
|
||||||
|
mock_client.subscribe.return_value = mock_consumer
|
||||||
|
|
||||||
|
subscriber = Subscriber(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
subscription="test-subscription",
|
||||||
|
consumer_name="test-consumer",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10,
|
||||||
|
drain_timeout=0.1 # Very short timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
await subscriber.start()
|
||||||
|
|
||||||
|
# Create subscription with many messages
|
||||||
|
queue = await subscriber.subscribe("test-queue")
|
||||||
|
for i in range(20):
|
||||||
|
await queue.put({"data": f"msg{i}"})
|
||||||
|
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
await subscriber.stop()
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
# Should timeout quickly
|
||||||
|
assert end_time - start_time < 1.0
|
||||||
|
|
||||||
|
# Queue should still have messages (drain timed out)
|
||||||
|
assert not queue.empty()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscriber_pending_acks_cleanup():
|
||||||
|
"""Test Subscriber cleans up pending acknowledgments on shutdown."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_consumer = MagicMock()
|
||||||
|
mock_client.subscribe.return_value = mock_consumer
|
||||||
|
|
||||||
|
subscriber = Subscriber(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
subscription="test-subscription",
|
||||||
|
consumer_name="test-consumer",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10
|
||||||
|
)
|
||||||
|
|
||||||
|
await subscriber.start()
|
||||||
|
|
||||||
|
# Add pending acknowledgments manually (simulating in-flight messages)
|
||||||
|
msg1 = create_mock_message("msg-1")
|
||||||
|
msg2 = create_mock_message("msg-2")
|
||||||
|
subscriber.pending_acks["ack-1"] = msg1
|
||||||
|
subscriber.pending_acks["ack-2"] = msg2
|
||||||
|
|
||||||
|
# Stop subscriber
|
||||||
|
await subscriber.stop()
|
||||||
|
|
||||||
|
# Should negative acknowledge pending messages
|
||||||
|
assert mock_consumer.negative_acknowledge.call_count == 2
|
||||||
|
mock_consumer.negative_acknowledge.assert_any_call(msg1)
|
||||||
|
mock_consumer.negative_acknowledge.assert_any_call(msg2)
|
||||||
|
|
||||||
|
# Pending acks should be cleared
|
||||||
|
assert len(subscriber.pending_acks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscriber_multiple_subscribers():
|
||||||
|
"""Test Subscriber with multiple concurrent subscribers."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_consumer = MagicMock()
|
||||||
|
mock_client.subscribe.return_value = mock_consumer
|
||||||
|
|
||||||
|
subscriber = Subscriber(
|
||||||
|
client=mock_client,
|
||||||
|
topic="test-topic",
|
||||||
|
subscription="test-subscription",
|
||||||
|
consumer_name="test-consumer",
|
||||||
|
schema=dict,
|
||||||
|
max_size=10
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create multiple subscriptions
|
||||||
|
queue1 = await subscriber.subscribe("queue-1")
|
||||||
|
queue2 = await subscriber.subscribe("queue-2")
|
||||||
|
queue_all = await subscriber.subscribe_all("queue-all")
|
||||||
|
|
||||||
|
# Process message
|
||||||
|
msg = create_mock_message("msg-1", {"data": "broadcast"})
|
||||||
|
await subscriber._process_message(msg)
|
||||||
|
|
||||||
|
# Should acknowledge (successful delivery to all queues)
|
||||||
|
mock_consumer.acknowledge.assert_called_once_with(msg)
|
||||||
|
|
||||||
|
# Message should be in specific queue (queue-1) and broadcast queue
|
||||||
|
assert not queue1.empty()
|
||||||
|
assert queue2.empty() # No message for queue-2
|
||||||
|
assert not queue_all.empty()
|
||||||
|
|
||||||
|
# Verify message content
|
||||||
|
msg1 = await queue1.get()
|
||||||
|
msg_all = await queue_all.get()
|
||||||
|
assert msg1 == {"data": "broadcast"}
|
||||||
|
assert msg_all == {"data": "broadcast"}
|
||||||
324
tests/unit/test_gateway/test_socket_graceful_shutdown.py
Normal file
324
tests/unit/test_gateway/test_socket_graceful_shutdown.py
Normal file
|
|
@ -0,0 +1,324 @@
|
||||||
|
"""Unit tests for SocketEndpoint graceful shutdown functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from aiohttp import web, WSMsgType
|
||||||
|
from trustgraph.gateway.endpoint.socket import SocketEndpoint
|
||||||
|
from trustgraph.gateway.running import Running
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_auth():
|
||||||
|
"""Mock authentication service."""
|
||||||
|
auth = MagicMock()
|
||||||
|
auth.permitted.return_value = True
|
||||||
|
return auth
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dispatcher_factory():
|
||||||
|
"""Mock dispatcher factory function."""
|
||||||
|
async def dispatcher_factory(ws, running, match_info):
|
||||||
|
dispatcher = AsyncMock()
|
||||||
|
dispatcher.run = AsyncMock()
|
||||||
|
dispatcher.receive = AsyncMock()
|
||||||
|
dispatcher.destroy = AsyncMock()
|
||||||
|
return dispatcher
|
||||||
|
|
||||||
|
return dispatcher_factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def socket_endpoint(mock_auth, mock_dispatcher_factory):
|
||||||
|
"""Create SocketEndpoint for testing."""
|
||||||
|
return SocketEndpoint(
|
||||||
|
endpoint_path="/test-socket",
|
||||||
|
auth=mock_auth,
|
||||||
|
dispatcher=mock_dispatcher_factory
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_websocket():
|
||||||
|
"""Mock websocket response."""
|
||||||
|
ws = AsyncMock(spec=web.WebSocketResponse)
|
||||||
|
ws.prepare = AsyncMock()
|
||||||
|
ws.close = AsyncMock()
|
||||||
|
ws.closed = False
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_request():
|
||||||
|
"""Mock HTTP request."""
|
||||||
|
request = MagicMock()
|
||||||
|
request.query = {"token": "test-token"}
|
||||||
|
request.match_info = {}
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_listener_graceful_shutdown_on_close():
|
||||||
|
"""Test listener handles websocket close gracefully."""
|
||||||
|
socket_endpoint = SocketEndpoint("/test", MagicMock(), AsyncMock())
|
||||||
|
|
||||||
|
# Mock websocket that closes after one message
|
||||||
|
ws = AsyncMock()
|
||||||
|
|
||||||
|
# Create async iterator that yields one message then closes
|
||||||
|
async def mock_iterator():
|
||||||
|
# Yield normal message
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = WSMsgType.TEXT
|
||||||
|
yield msg
|
||||||
|
|
||||||
|
# Yield close message
|
||||||
|
close_msg = MagicMock()
|
||||||
|
close_msg.type = WSMsgType.CLOSE
|
||||||
|
yield close_msg
|
||||||
|
|
||||||
|
ws.__aiter__ = mock_iterator
|
||||||
|
|
||||||
|
dispatcher = AsyncMock()
|
||||||
|
running = Running()
|
||||||
|
|
||||||
|
with patch('asyncio.sleep') as mock_sleep:
|
||||||
|
await socket_endpoint.listener(ws, dispatcher, running)
|
||||||
|
|
||||||
|
# Should have processed one message
|
||||||
|
dispatcher.receive.assert_called_once()
|
||||||
|
|
||||||
|
# Should have initiated graceful shutdown
|
||||||
|
assert running.get() is False
|
||||||
|
|
||||||
|
# Should have slept for grace period
|
||||||
|
mock_sleep.assert_called_once_with(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_normal_flow():
|
||||||
|
"""Test normal websocket handling flow."""
|
||||||
|
mock_auth = MagicMock()
|
||||||
|
mock_auth.permitted.return_value = True
|
||||||
|
|
||||||
|
dispatcher_created = False
|
||||||
|
async def mock_dispatcher_factory(ws, running, match_info):
|
||||||
|
nonlocal dispatcher_created
|
||||||
|
dispatcher_created = True
|
||||||
|
dispatcher = AsyncMock()
|
||||||
|
dispatcher.destroy = AsyncMock()
|
||||||
|
return dispatcher
|
||||||
|
|
||||||
|
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.query = {"token": "valid-token"}
|
||||||
|
request.match_info = {}
|
||||||
|
|
||||||
|
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
mock_ws.prepare = AsyncMock()
|
||||||
|
mock_ws.close = AsyncMock()
|
||||||
|
mock_ws.closed = False
|
||||||
|
mock_ws_class.return_value = mock_ws
|
||||||
|
|
||||||
|
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||||
|
# Mock task group context manager
|
||||||
|
mock_tg = AsyncMock()
|
||||||
|
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||||
|
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_tg.create_task = MagicMock(return_value=AsyncMock())
|
||||||
|
mock_task_group.return_value = mock_tg
|
||||||
|
|
||||||
|
result = await socket_endpoint.handle(request)
|
||||||
|
|
||||||
|
# Should have created dispatcher
|
||||||
|
assert dispatcher_created is True
|
||||||
|
|
||||||
|
# Should return websocket
|
||||||
|
assert result == mock_ws
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_exception_group_cleanup():
|
||||||
|
"""Test exception group triggers dispatcher cleanup."""
|
||||||
|
mock_auth = MagicMock()
|
||||||
|
mock_auth.permitted.return_value = True
|
||||||
|
|
||||||
|
mock_dispatcher = AsyncMock()
|
||||||
|
mock_dispatcher.destroy = AsyncMock()
|
||||||
|
|
||||||
|
async def mock_dispatcher_factory(ws, running, match_info):
|
||||||
|
return mock_dispatcher
|
||||||
|
|
||||||
|
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.query = {"token": "valid-token"}
|
||||||
|
request.match_info = {}
|
||||||
|
|
||||||
|
# Mock TaskGroup to raise ExceptionGroup
|
||||||
|
class TestException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
exception_group = ExceptionGroup("Test exceptions", [TestException("test")])
|
||||||
|
|
||||||
|
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
mock_ws.prepare = AsyncMock()
|
||||||
|
mock_ws.close = AsyncMock()
|
||||||
|
mock_ws.closed = False
|
||||||
|
mock_ws_class.return_value = mock_ws
|
||||||
|
|
||||||
|
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||||
|
mock_tg = AsyncMock()
|
||||||
|
mock_tg.__aenter__ = AsyncMock(side_effect=exception_group)
|
||||||
|
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_task_group.return_value = mock_tg
|
||||||
|
|
||||||
|
with patch('asyncio.wait_for') as mock_wait_for:
|
||||||
|
mock_wait_for.return_value = None
|
||||||
|
|
||||||
|
result = await socket_endpoint.handle(request)
|
||||||
|
|
||||||
|
# Should have attempted graceful cleanup
|
||||||
|
mock_wait_for.assert_called_once()
|
||||||
|
|
||||||
|
# Should have called destroy in finally block
|
||||||
|
assert mock_dispatcher.destroy.call_count >= 1
|
||||||
|
|
||||||
|
# Should have closed websocket
|
||||||
|
mock_ws.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_dispatcher_cleanup_timeout():
|
||||||
|
"""Test dispatcher cleanup with timeout."""
|
||||||
|
mock_auth = MagicMock()
|
||||||
|
mock_auth.permitted.return_value = True
|
||||||
|
|
||||||
|
# Mock dispatcher that takes long to destroy
|
||||||
|
mock_dispatcher = AsyncMock()
|
||||||
|
mock_dispatcher.destroy = AsyncMock()
|
||||||
|
|
||||||
|
async def mock_dispatcher_factory(ws, running, match_info):
|
||||||
|
return mock_dispatcher
|
||||||
|
|
||||||
|
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.query = {"token": "valid-token"}
|
||||||
|
request.match_info = {}
|
||||||
|
|
||||||
|
# Mock TaskGroup to raise exception
|
||||||
|
exception_group = ExceptionGroup("Test", [Exception("test")])
|
||||||
|
|
||||||
|
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
mock_ws.prepare = AsyncMock()
|
||||||
|
mock_ws.close = AsyncMock()
|
||||||
|
mock_ws.closed = False
|
||||||
|
mock_ws_class.return_value = mock_ws
|
||||||
|
|
||||||
|
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||||
|
mock_tg = AsyncMock()
|
||||||
|
mock_tg.__aenter__ = AsyncMock(side_effect=exception_group)
|
||||||
|
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_task_group.return_value = mock_tg
|
||||||
|
|
||||||
|
# Mock asyncio.wait_for to raise TimeoutError
|
||||||
|
with patch('asyncio.wait_for') as mock_wait_for:
|
||||||
|
mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")
|
||||||
|
|
||||||
|
result = await socket_endpoint.handle(request)
|
||||||
|
|
||||||
|
# Should have attempted cleanup with timeout
|
||||||
|
mock_wait_for.assert_called_once_with(
|
||||||
|
mock_dispatcher.destroy(),
|
||||||
|
timeout=5.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still call destroy in finally block
|
||||||
|
assert mock_dispatcher.destroy.call_count >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_unauthorized_request():
|
||||||
|
"""Test handling of unauthorized requests."""
|
||||||
|
mock_auth = MagicMock()
|
||||||
|
mock_auth.permitted.return_value = False # Unauthorized
|
||||||
|
|
||||||
|
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.query = {"token": "invalid-token"}
|
||||||
|
|
||||||
|
result = await socket_endpoint.handle(request)
|
||||||
|
|
||||||
|
# Should return HTTP 401
|
||||||
|
assert isinstance(result, web.HTTPUnauthorized)
|
||||||
|
|
||||||
|
# Should have checked permission
|
||||||
|
mock_auth.permitted.assert_called_once_with("invalid-token", "socket")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_missing_token():
|
||||||
|
"""Test handling of requests with missing token."""
|
||||||
|
mock_auth = MagicMock()
|
||||||
|
mock_auth.permitted.return_value = False
|
||||||
|
|
||||||
|
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.query = {} # No token
|
||||||
|
|
||||||
|
result = await socket_endpoint.handle(request)
|
||||||
|
|
||||||
|
# Should return HTTP 401
|
||||||
|
assert isinstance(result, web.HTTPUnauthorized)
|
||||||
|
|
||||||
|
# Should have checked permission with empty token
|
||||||
|
mock_auth.permitted.assert_called_once_with("", "socket")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_websocket_already_closed():
|
||||||
|
"""Test handling when websocket is already closed."""
|
||||||
|
mock_auth = MagicMock()
|
||||||
|
mock_auth.permitted.return_value = True
|
||||||
|
|
||||||
|
mock_dispatcher = AsyncMock()
|
||||||
|
mock_dispatcher.destroy = AsyncMock()
|
||||||
|
|
||||||
|
async def mock_dispatcher_factory(ws, running, match_info):
|
||||||
|
return mock_dispatcher
|
||||||
|
|
||||||
|
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.query = {"token": "valid-token"}
|
||||||
|
request.match_info = {}
|
||||||
|
|
||||||
|
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
mock_ws.prepare = AsyncMock()
|
||||||
|
mock_ws.close = AsyncMock()
|
||||||
|
mock_ws.closed = True # Already closed
|
||||||
|
mock_ws_class.return_value = mock_ws
|
||||||
|
|
||||||
|
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||||
|
mock_tg = AsyncMock()
|
||||||
|
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||||
|
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_tg.create_task = MagicMock(return_value=AsyncMock())
|
||||||
|
mock_task_group.return_value = mock_tg
|
||||||
|
|
||||||
|
result = await socket_endpoint.handle(request)
|
||||||
|
|
||||||
|
# Should still have called destroy
|
||||||
|
mock_dispatcher.destroy.assert_called()
|
||||||
|
|
||||||
|
# Should not attempt to close already closed websocket
|
||||||
|
mock_ws.close.assert_not_called() # Not called in finally since ws.closed = True
|
||||||
Loading…
Add table
Add a link
Reference in a new issue