mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 18:36:22 +02:00
parent
a8e437fc7f
commit
6c7af8789d
216 changed files with 31360 additions and 1611 deletions
|
|
@ -63,6 +63,7 @@ class TestSocketEndpoint:
|
|||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.__aiter__ = lambda self: async_iter()
|
||||
mock_ws.closed = False # Set closed attribute
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call listener method
|
||||
|
|
@ -92,6 +93,7 @@ class TestSocketEndpoint:
|
|||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.__aiter__ = lambda self: async_iter()
|
||||
mock_ws.closed = False # Set closed attribute
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call listener method
|
||||
|
|
@ -121,6 +123,7 @@ class TestSocketEndpoint:
|
|||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.__aiter__ = lambda self: async_iter()
|
||||
mock_ws.closed = False # Set closed attribute
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call listener method
|
||||
|
|
|
|||
546
tests/unit/test_gateway/test_objects_import_dispatcher.py
Normal file
546
tests/unit/test_gateway/test_objects_import_dispatcher.py
Normal file
|
|
@ -0,0 +1,546 @@
|
|||
"""
|
||||
Unit tests for objects import dispatcher.
|
||||
|
||||
Tests the business logic of objects import dispatcher
|
||||
while mocking the Publisher and websocket components.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import asyncio
|
||||
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||
from aiohttp import web
|
||||
|
||||
from trustgraph.gateway.dispatch.objects_import import ObjectsImport
|
||||
from trustgraph.schema import Metadata, ExtractedObject
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pulsar_client():
|
||||
"""Mock Pulsar client."""
|
||||
client = Mock()
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_publisher():
|
||||
"""Mock Publisher with async methods."""
|
||||
publisher = Mock()
|
||||
publisher.start = AsyncMock()
|
||||
publisher.stop = AsyncMock()
|
||||
publisher.send = AsyncMock()
|
||||
return publisher
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_running():
|
||||
"""Mock Running state handler."""
|
||||
running = Mock()
|
||||
running.get.return_value = True
|
||||
running.stop = Mock()
|
||||
return running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
"""Mock WebSocket connection."""
|
||||
ws = Mock()
|
||||
ws.close = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_objects_message():
|
||||
"""Sample objects message data."""
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "obj-123",
|
||||
"metadata": [
|
||||
{
|
||||
"s": {"v": "obj-123", "e": False},
|
||||
"p": {"v": "source", "e": False},
|
||||
"o": {"v": "test", "e": False}
|
||||
}
|
||||
],
|
||||
"user": "testuser",
|
||||
"collection": "testcollection"
|
||||
},
|
||||
"schema_name": "person",
|
||||
"values": [{
|
||||
"name": "John Doe",
|
||||
"age": "30",
|
||||
"city": "New York"
|
||||
}],
|
||||
"confidence": 0.95,
|
||||
"source_span": "John Doe, age 30, lives in New York"
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def minimal_objects_message():
|
||||
"""Minimal required objects message data."""
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "obj-456",
|
||||
"user": "testuser",
|
||||
"collection": "testcollection"
|
||||
},
|
||||
"schema_name": "simple_schema",
|
||||
"values": [{
|
||||
"field1": "value1"
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
class TestObjectsImportInitialization:
|
||||
"""Test ObjectsImport initialization."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
"""Test that ObjectsImport creates Publisher with correct parameters."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-objects-queue"
|
||||
)
|
||||
|
||||
# Verify Publisher was created with correct parameters
|
||||
mock_publisher_class.assert_called_once_with(
|
||||
mock_pulsar_client,
|
||||
topic="test-objects-queue",
|
||||
schema=ExtractedObject
|
||||
)
|
||||
|
||||
# Verify instance variables are set correctly
|
||||
assert objects_import.ws == mock_websocket
|
||||
assert objects_import.running == mock_running
|
||||
assert objects_import.publisher == mock_publisher_instance
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
def test_init_stores_references_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
"""Test that ObjectsImport stores all required references."""
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="objects-queue"
|
||||
)
|
||||
|
||||
assert objects_import.ws is mock_websocket
|
||||
assert objects_import.running is mock_running
|
||||
|
||||
|
||||
class TestObjectsImportLifecycle:
|
||||
"""Test ObjectsImport lifecycle methods."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
"""Test that start() calls publisher.start()."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.start = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.start()
|
||||
|
||||
mock_publisher_instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
"""Test that destroy() properly stops publisher and closes websocket."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.destroy()
|
||||
|
||||
# Verify sequence of operations
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_publisher_instance.stop.assert_called_once()
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_pulsar_client, mock_running):
|
||||
"""Test that destroy() handles None websocket gracefully."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=None, # None websocket
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
# Should not raise exception
|
||||
await objects_import.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_publisher_instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestObjectsImportMessageProcessing:
|
||||
"""Test ObjectsImport message processing."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message):
|
||||
"""Test that receive() processes complete message correctly."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
||||
# Get the call arguments
|
||||
call_args = mock_publisher_instance.send.call_args
|
||||
assert call_args[0][0] is None # First argument should be None
|
||||
|
||||
# Check the ExtractedObject that was sent
|
||||
sent_object = call_args[0][1]
|
||||
assert isinstance(sent_object, ExtractedObject)
|
||||
assert sent_object.schema_name == "person"
|
||||
assert sent_object.values[0]["name"] == "John Doe"
|
||||
assert sent_object.values[0]["age"] == "30"
|
||||
assert sent_object.confidence == 0.95
|
||||
assert sent_object.source_span == "John Doe, age 30, lives in New York"
|
||||
|
||||
# Check metadata
|
||||
assert sent_object.metadata.id == "obj-123"
|
||||
assert sent_object.metadata.user == "testuser"
|
||||
assert sent_object.metadata.collection == "testcollection"
|
||||
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, minimal_objects_message):
|
||||
"""Test that receive() handles message with minimal required fields."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = minimal_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
||||
# Get the sent object
|
||||
sent_object = mock_publisher_instance.send.call_args[0][1]
|
||||
assert isinstance(sent_object, ExtractedObject)
|
||||
assert sent_object.schema_name == "simple_schema"
|
||||
assert sent_object.values[0]["field1"] == "value1"
|
||||
assert sent_object.confidence == 1.0 # Default value
|
||||
assert sent_object.source_span == "" # Default value
|
||||
assert len(sent_object.metadata.metadata) == 0 # Default empty list
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_uses_default_values(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
"""Test that receive() uses appropriate default values for optional fields."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
# Message without optional fields
|
||||
message_data = {
|
||||
"metadata": {
|
||||
"id": "obj-789",
|
||||
"user": "testuser",
|
||||
"collection": "testcollection"
|
||||
},
|
||||
"schema_name": "test_schema",
|
||||
"values": [{"key": "value"}]
|
||||
# No confidence or source_span
|
||||
}
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = message_data
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
|
||||
# Get the sent object and verify defaults
|
||||
sent_object = mock_publisher_instance.send.call_args[0][1]
|
||||
assert sent_object.confidence == 1.0
|
||||
assert sent_object.source_span == ""
|
||||
|
||||
|
||||
class TestObjectsImportRunMethod:
|
||||
"""Test ObjectsImport run method."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
"""Test that run() loops while running.get() returns True."""
|
||||
mock_sleep.return_value = None
|
||||
mock_publisher_class.return_value = Mock()
|
||||
|
||||
# Set up running state to return True twice, then False
|
||||
mock_running.get.side_effect = [True, True, False]
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.run()
|
||||
|
||||
# Verify sleep was called twice (for the two True iterations)
|
||||
assert mock_sleep.call_count == 2
|
||||
mock_sleep.assert_called_with(0.5)
|
||||
|
||||
# Verify websocket was closed
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
# Verify websocket was set to None
|
||||
assert objects_import.ws is None
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_running):
|
||||
"""Test that run() handles None websocket gracefully."""
|
||||
mock_sleep.return_value = None
|
||||
mock_publisher_class.return_value = Mock()
|
||||
|
||||
mock_running.get.return_value = False # Exit immediately
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=None, # None websocket
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
# Should not raise exception
|
||||
await objects_import.run()
|
||||
|
||||
# Verify websocket remains None
|
||||
assert objects_import.ws is None
|
||||
|
||||
|
||||
class TestObjectsImportBatchProcessing:
|
||||
"""Test ObjectsImport batch processing functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def batch_objects_message(self):
|
||||
"""Sample batch objects message data."""
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "batch-001",
|
||||
"metadata": [
|
||||
{
|
||||
"s": {"v": "batch-001", "e": False},
|
||||
"p": {"v": "source", "e": False},
|
||||
"o": {"v": "test", "e": False}
|
||||
}
|
||||
],
|
||||
"user": "testuser",
|
||||
"collection": "testcollection"
|
||||
},
|
||||
"schema_name": "person",
|
||||
"values": [
|
||||
{
|
||||
"name": "John Doe",
|
||||
"age": "30",
|
||||
"city": "New York"
|
||||
},
|
||||
{
|
||||
"name": "Jane Smith",
|
||||
"age": "25",
|
||||
"city": "Boston"
|
||||
},
|
||||
{
|
||||
"name": "Bob Johnson",
|
||||
"age": "45",
|
||||
"city": "Chicago"
|
||||
}
|
||||
],
|
||||
"confidence": 0.85,
|
||||
"source_span": "Multiple people found in document"
|
||||
}
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, batch_objects_message):
|
||||
"""Test that receive() processes batch message correctly."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = batch_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
||||
# Get the call arguments
|
||||
call_args = mock_publisher_instance.send.call_args
|
||||
assert call_args[0][0] is None # First argument should be None
|
||||
|
||||
# Check the ExtractedObject that was sent
|
||||
sent_object = call_args[0][1]
|
||||
assert isinstance(sent_object, ExtractedObject)
|
||||
assert sent_object.schema_name == "person"
|
||||
|
||||
# Check that all batch values are present
|
||||
assert len(sent_object.values) == 3
|
||||
assert sent_object.values[0]["name"] == "John Doe"
|
||||
assert sent_object.values[0]["age"] == "30"
|
||||
assert sent_object.values[0]["city"] == "New York"
|
||||
|
||||
assert sent_object.values[1]["name"] == "Jane Smith"
|
||||
assert sent_object.values[1]["age"] == "25"
|
||||
assert sent_object.values[1]["city"] == "Boston"
|
||||
|
||||
assert sent_object.values[2]["name"] == "Bob Johnson"
|
||||
assert sent_object.values[2]["age"] == "45"
|
||||
assert sent_object.values[2]["city"] == "Chicago"
|
||||
|
||||
assert sent_object.confidence == 0.85
|
||||
assert sent_object.source_span == "Multiple people found in document"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
"""Test that receive() handles empty batch correctly."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
# Message with empty values array
|
||||
empty_batch_message = {
|
||||
"metadata": {
|
||||
"id": "empty-batch-001",
|
||||
"user": "testuser",
|
||||
"collection": "testcollection"
|
||||
},
|
||||
"schema_name": "empty_schema",
|
||||
"values": []
|
||||
}
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = empty_batch_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
|
||||
# Should still send the message
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
sent_object = mock_publisher_instance.send.call_args[0][1]
|
||||
assert len(sent_object.values) == 0
|
||||
|
||||
|
||||
class TestObjectsImportErrorHandling:
|
||||
"""Test error handling in ObjectsImport."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message):
|
||||
"""Test that receive() propagates publisher send errors."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_objects_message
|
||||
|
||||
with pytest.raises(Exception, match="Publisher error"):
|
||||
await objects_import.receive(mock_msg)
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
"""Test that receive() handles malformed JSON appropriately."""
|
||||
mock_publisher_class.return_value = Mock()
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
await objects_import.receive(mock_msg)
|
||||
326
tests/unit/test_gateway/test_socket_graceful_shutdown.py
Normal file
326
tests/unit/test_gateway/test_socket_graceful_shutdown.py
Normal file
|
|
@ -0,0 +1,326 @@
|
|||
"""Unit tests for SocketEndpoint graceful shutdown functionality."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from aiohttp import web, WSMsgType
|
||||
from trustgraph.gateway.endpoint.socket import SocketEndpoint
|
||||
from trustgraph.gateway.running import Running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth():
|
||||
"""Mock authentication service."""
|
||||
auth = MagicMock()
|
||||
auth.permitted.return_value = True
|
||||
return auth
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dispatcher_factory():
|
||||
"""Mock dispatcher factory function."""
|
||||
async def dispatcher_factory(ws, running, match_info):
|
||||
dispatcher = AsyncMock()
|
||||
dispatcher.run = AsyncMock()
|
||||
dispatcher.receive = AsyncMock()
|
||||
dispatcher.destroy = AsyncMock()
|
||||
return dispatcher
|
||||
|
||||
return dispatcher_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socket_endpoint(mock_auth, mock_dispatcher_factory):
|
||||
"""Create SocketEndpoint for testing."""
|
||||
return SocketEndpoint(
|
||||
endpoint_path="/test-socket",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher_factory
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
"""Mock websocket response."""
|
||||
ws = AsyncMock(spec=web.WebSocketResponse)
|
||||
ws.prepare = AsyncMock()
|
||||
ws.close = AsyncMock()
|
||||
ws.closed = False
|
||||
return ws
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Mock HTTP request."""
|
||||
request = MagicMock()
|
||||
request.query = {"token": "test-token"}
|
||||
request.match_info = {}
|
||||
return request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listener_graceful_shutdown_on_close():
|
||||
"""Test listener handles websocket close gracefully."""
|
||||
socket_endpoint = SocketEndpoint("/test", MagicMock(), AsyncMock())
|
||||
|
||||
# Mock websocket that closes after one message
|
||||
ws = AsyncMock()
|
||||
|
||||
# Create async iterator that yields one message then closes
|
||||
async def mock_iterator(self):
|
||||
# Yield normal message
|
||||
msg = MagicMock()
|
||||
msg.type = WSMsgType.TEXT
|
||||
yield msg
|
||||
|
||||
# Yield close message
|
||||
close_msg = MagicMock()
|
||||
close_msg.type = WSMsgType.CLOSE
|
||||
yield close_msg
|
||||
|
||||
# Set the async iterator method
|
||||
ws.__aiter__ = mock_iterator
|
||||
|
||||
dispatcher = AsyncMock()
|
||||
running = Running()
|
||||
|
||||
with patch('asyncio.sleep') as mock_sleep:
|
||||
await socket_endpoint.listener(ws, dispatcher, running)
|
||||
|
||||
# Should have processed one message
|
||||
dispatcher.receive.assert_called_once()
|
||||
|
||||
# Should have initiated graceful shutdown
|
||||
assert running.get() is False
|
||||
|
||||
# Should have slept for grace period
|
||||
mock_sleep.assert_called_once_with(1.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_normal_flow():
|
||||
"""Test normal websocket handling flow."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
dispatcher_created = False
|
||||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
nonlocal dispatcher_created
|
||||
dispatcher_created = True
|
||||
dispatcher = AsyncMock()
|
||||
dispatcher.destroy = AsyncMock()
|
||||
return dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
request.match_info = {}
|
||||
|
||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.prepare = AsyncMock()
|
||||
mock_ws.close = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
mock_ws_class.return_value = mock_ws
|
||||
|
||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||
# Mock task group context manager
|
||||
mock_tg = AsyncMock()
|
||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_tg.create_task = MagicMock(return_value=AsyncMock())
|
||||
mock_task_group.return_value = mock_tg
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
# Should have created dispatcher
|
||||
assert dispatcher_created is True
|
||||
|
||||
# Should return websocket
|
||||
assert result == mock_ws
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_exception_group_cleanup():
|
||||
"""Test exception group triggers dispatcher cleanup."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.destroy = AsyncMock()
|
||||
|
||||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
request.match_info = {}
|
||||
|
||||
# Mock TaskGroup to raise ExceptionGroup
|
||||
class TestException(Exception):
|
||||
pass
|
||||
|
||||
exception_group = ExceptionGroup("Test exceptions", [TestException("test")])
|
||||
|
||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.prepare = AsyncMock()
|
||||
mock_ws.close = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
mock_ws_class.return_value = mock_ws
|
||||
|
||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||
mock_tg = AsyncMock()
|
||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
|
||||
mock_tg.create_task = MagicMock(side_effect=TestException("test"))
|
||||
mock_task_group.return_value = mock_tg
|
||||
|
||||
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
|
||||
mock_wait_for.return_value = None
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
# Should have attempted graceful cleanup
|
||||
mock_wait_for.assert_called_once()
|
||||
|
||||
# Should have called destroy in finally block
|
||||
assert mock_dispatcher.destroy.call_count >= 1
|
||||
|
||||
# Should have closed websocket
|
||||
mock_ws.close.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_dispatcher_cleanup_timeout():
|
||||
"""Test dispatcher cleanup with timeout."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
# Mock dispatcher that takes long to destroy
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.destroy = AsyncMock()
|
||||
|
||||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
request.match_info = {}
|
||||
|
||||
# Mock TaskGroup to raise exception
|
||||
exception_group = ExceptionGroup("Test", [Exception("test")])
|
||||
|
||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.prepare = AsyncMock()
|
||||
mock_ws.close = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
mock_ws_class.return_value = mock_ws
|
||||
|
||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||
mock_tg = AsyncMock()
|
||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
|
||||
mock_tg.create_task = MagicMock(side_effect=Exception("test"))
|
||||
mock_task_group.return_value = mock_tg
|
||||
|
||||
# Mock asyncio.wait_for to raise TimeoutError
|
||||
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
|
||||
mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
# Should have attempted cleanup with timeout
|
||||
mock_wait_for.assert_called_once()
|
||||
# Check that timeout was passed correctly
|
||||
assert mock_wait_for.call_args[1]['timeout'] == 5.0
|
||||
|
||||
# Should still call destroy in finally block
|
||||
assert mock_dispatcher.destroy.call_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_unauthorized_request():
|
||||
"""Test handling of unauthorized requests."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = False # Unauthorized
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "invalid-token"}
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
# Should return HTTP 401
|
||||
assert isinstance(result, web.HTTPUnauthorized)
|
||||
|
||||
# Should have checked permission
|
||||
mock_auth.permitted.assert_called_once_with("invalid-token", "socket")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_missing_token():
|
||||
"""Test handling of requests with missing token."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = False
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {} # No token
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
# Should return HTTP 401
|
||||
assert isinstance(result, web.HTTPUnauthorized)
|
||||
|
||||
# Should have checked permission with empty token
|
||||
mock_auth.permitted.assert_called_once_with("", "socket")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_websocket_already_closed():
|
||||
"""Test handling when websocket is already closed."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.destroy = AsyncMock()
|
||||
|
||||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
request.match_info = {}
|
||||
|
||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.prepare = AsyncMock()
|
||||
mock_ws.close = AsyncMock()
|
||||
mock_ws.closed = True # Already closed
|
||||
mock_ws_class.return_value = mock_ws
|
||||
|
||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||
mock_tg = AsyncMock()
|
||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_tg.create_task = MagicMock(return_value=AsyncMock())
|
||||
mock_task_group.return_value = mock_tg
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
# Should still have called destroy
|
||||
mock_dispatcher.destroy.assert_called()
|
||||
|
||||
# Should not attempt to close already closed websocket
|
||||
mock_ws.close.assert_not_called() # Not called in finally since ws.closed = True
|
||||
Loading…
Add table
Add a link
Reference in a new issue