Release 1.4 -> master (#524)

Catch up
This commit is contained in:
cybermaggedon 2025-09-20 16:00:37 +01:00 committed by GitHub
parent a8e437fc7f
commit 6c7af8789d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
216 changed files with 31360 additions and 1611 deletions

View file

@ -63,6 +63,7 @@ class TestSocketEndpoint:
mock_ws = AsyncMock()
mock_ws.__aiter__ = lambda self: async_iter()
mock_ws.closed = False # Set closed attribute
mock_running = MagicMock()
# Call listener method
@ -92,6 +93,7 @@ class TestSocketEndpoint:
mock_ws = AsyncMock()
mock_ws.__aiter__ = lambda self: async_iter()
mock_ws.closed = False # Set closed attribute
mock_running = MagicMock()
# Call listener method
@ -121,6 +123,7 @@ class TestSocketEndpoint:
mock_ws = AsyncMock()
mock_ws.__aiter__ = lambda self: async_iter()
mock_ws.closed = False # Set closed attribute
mock_running = MagicMock()
# Call listener method

View file

@ -0,0 +1,546 @@
"""
Unit tests for objects import dispatcher.
Tests the business logic of objects import dispatcher
while mocking the Publisher and websocket components.
"""
import pytest
import json
import asyncio
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from aiohttp import web
from trustgraph.gateway.dispatch.objects_import import ObjectsImport
from trustgraph.schema import Metadata, ExtractedObject
@pytest.fixture
def mock_pulsar_client():
"""Mock Pulsar client."""
client = Mock()
return client
@pytest.fixture
def mock_publisher():
"""Mock Publisher with async methods."""
publisher = Mock()
publisher.start = AsyncMock()
publisher.stop = AsyncMock()
publisher.send = AsyncMock()
return publisher
@pytest.fixture
def mock_running():
"""Mock Running state handler."""
running = Mock()
running.get.return_value = True
running.stop = Mock()
return running
@pytest.fixture
def mock_websocket():
"""Mock WebSocket connection."""
ws = Mock()
ws.close = AsyncMock()
return ws
@pytest.fixture
def sample_objects_message():
"""Sample objects message data."""
return {
"metadata": {
"id": "obj-123",
"metadata": [
{
"s": {"v": "obj-123", "e": False},
"p": {"v": "source", "e": False},
"o": {"v": "test", "e": False}
}
],
"user": "testuser",
"collection": "testcollection"
},
"schema_name": "person",
"values": [{
"name": "John Doe",
"age": "30",
"city": "New York"
}],
"confidence": 0.95,
"source_span": "John Doe, age 30, lives in New York"
}
@pytest.fixture
def minimal_objects_message():
"""Minimal required objects message data."""
return {
"metadata": {
"id": "obj-456",
"user": "testuser",
"collection": "testcollection"
},
"schema_name": "simple_schema",
"values": [{
"field1": "value1"
}]
}
class TestObjectsImportInitialization:
"""Test ObjectsImport initialization."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
"""Test that ObjectsImport creates Publisher with correct parameters."""
mock_publisher_instance = Mock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-objects-queue"
)
# Verify Publisher was created with correct parameters
mock_publisher_class.assert_called_once_with(
mock_pulsar_client,
topic="test-objects-queue",
schema=ExtractedObject
)
# Verify instance variables are set correctly
assert objects_import.ws == mock_websocket
assert objects_import.running == mock_running
assert objects_import.publisher == mock_publisher_instance
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
def test_init_stores_references_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
"""Test that ObjectsImport stores all required references."""
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="objects-queue"
)
assert objects_import.ws is mock_websocket
assert objects_import.running is mock_running
class TestObjectsImportLifecycle:
"""Test ObjectsImport lifecycle methods."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
"""Test that start() calls publisher.start()."""
mock_publisher_instance = Mock()
mock_publisher_instance.start = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
await objects_import.start()
mock_publisher_instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
"""Test that destroy() properly stops publisher and closes websocket."""
mock_publisher_instance = Mock()
mock_publisher_instance.stop = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
await objects_import.destroy()
# Verify sequence of operations
mock_running.stop.assert_called_once()
mock_publisher_instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_pulsar_client, mock_running):
"""Test that destroy() handles None websocket gracefully."""
mock_publisher_instance = Mock()
mock_publisher_instance.stop = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=None, # None websocket
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
# Should not raise exception
await objects_import.destroy()
mock_running.stop.assert_called_once()
mock_publisher_instance.stop.assert_called_once()
class TestObjectsImportMessageProcessing:
"""Test ObjectsImport message processing."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message):
"""Test that receive() processes complete message correctly."""
mock_publisher_instance = Mock()
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
# Create mock message
mock_msg = Mock()
mock_msg.json.return_value = sample_objects_message
await objects_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
# Get the call arguments
call_args = mock_publisher_instance.send.call_args
assert call_args[0][0] is None # First argument should be None
# Check the ExtractedObject that was sent
sent_object = call_args[0][1]
assert isinstance(sent_object, ExtractedObject)
assert sent_object.schema_name == "person"
assert sent_object.values[0]["name"] == "John Doe"
assert sent_object.values[0]["age"] == "30"
assert sent_object.confidence == 0.95
assert sent_object.source_span == "John Doe, age 30, lives in New York"
# Check metadata
assert sent_object.metadata.id == "obj-123"
assert sent_object.metadata.user == "testuser"
assert sent_object.metadata.collection == "testcollection"
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, minimal_objects_message):
"""Test that receive() handles message with minimal required fields."""
mock_publisher_instance = Mock()
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
# Create mock message
mock_msg = Mock()
mock_msg.json.return_value = minimal_objects_message
await objects_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
# Get the sent object
sent_object = mock_publisher_instance.send.call_args[0][1]
assert isinstance(sent_object, ExtractedObject)
assert sent_object.schema_name == "simple_schema"
assert sent_object.values[0]["field1"] == "value1"
assert sent_object.confidence == 1.0 # Default value
assert sent_object.source_span == "" # Default value
assert len(sent_object.metadata.metadata) == 0 # Default empty list
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_receive_uses_default_values(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
"""Test that receive() uses appropriate default values for optional fields."""
mock_publisher_instance = Mock()
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
# Message without optional fields
message_data = {
"metadata": {
"id": "obj-789",
"user": "testuser",
"collection": "testcollection"
},
"schema_name": "test_schema",
"values": [{"key": "value"}]
# No confidence or source_span
}
mock_msg = Mock()
mock_msg.json.return_value = message_data
await objects_import.receive(mock_msg)
# Get the sent object and verify defaults
sent_object = mock_publisher_instance.send.call_args[0][1]
assert sent_object.confidence == 1.0
assert sent_object.source_span == ""
class TestObjectsImportRunMethod:
"""Test ObjectsImport run method."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
@pytest.mark.asyncio
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
"""Test that run() loops while running.get() returns True."""
mock_sleep.return_value = None
mock_publisher_class.return_value = Mock()
# Set up running state to return True twice, then False
mock_running.get.side_effect = [True, True, False]
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
await objects_import.run()
# Verify sleep was called twice (for the two True iterations)
assert mock_sleep.call_count == 2
mock_sleep.assert_called_with(0.5)
# Verify websocket was closed
mock_websocket.close.assert_called_once()
# Verify websocket was set to None
assert objects_import.ws is None
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
@pytest.mark.asyncio
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_running):
"""Test that run() handles None websocket gracefully."""
mock_sleep.return_value = None
mock_publisher_class.return_value = Mock()
mock_running.get.return_value = False # Exit immediately
objects_import = ObjectsImport(
ws=None, # None websocket
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
# Should not raise exception
await objects_import.run()
# Verify websocket remains None
assert objects_import.ws is None
class TestObjectsImportBatchProcessing:
"""Test ObjectsImport batch processing functionality."""
@pytest.fixture
def batch_objects_message(self):
"""Sample batch objects message data."""
return {
"metadata": {
"id": "batch-001",
"metadata": [
{
"s": {"v": "batch-001", "e": False},
"p": {"v": "source", "e": False},
"o": {"v": "test", "e": False}
}
],
"user": "testuser",
"collection": "testcollection"
},
"schema_name": "person",
"values": [
{
"name": "John Doe",
"age": "30",
"city": "New York"
},
{
"name": "Jane Smith",
"age": "25",
"city": "Boston"
},
{
"name": "Bob Johnson",
"age": "45",
"city": "Chicago"
}
],
"confidence": 0.85,
"source_span": "Multiple people found in document"
}
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, batch_objects_message):
"""Test that receive() processes batch message correctly."""
mock_publisher_instance = Mock()
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
# Create mock message
mock_msg = Mock()
mock_msg.json.return_value = batch_objects_message
await objects_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
# Get the call arguments
call_args = mock_publisher_instance.send.call_args
assert call_args[0][0] is None # First argument should be None
# Check the ExtractedObject that was sent
sent_object = call_args[0][1]
assert isinstance(sent_object, ExtractedObject)
assert sent_object.schema_name == "person"
# Check that all batch values are present
assert len(sent_object.values) == 3
assert sent_object.values[0]["name"] == "John Doe"
assert sent_object.values[0]["age"] == "30"
assert sent_object.values[0]["city"] == "New York"
assert sent_object.values[1]["name"] == "Jane Smith"
assert sent_object.values[1]["age"] == "25"
assert sent_object.values[1]["city"] == "Boston"
assert sent_object.values[2]["name"] == "Bob Johnson"
assert sent_object.values[2]["age"] == "45"
assert sent_object.values[2]["city"] == "Chicago"
assert sent_object.confidence == 0.85
assert sent_object.source_span == "Multiple people found in document"
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
"""Test that receive() handles empty batch correctly."""
mock_publisher_instance = Mock()
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
# Message with empty values array
empty_batch_message = {
"metadata": {
"id": "empty-batch-001",
"user": "testuser",
"collection": "testcollection"
},
"schema_name": "empty_schema",
"values": []
}
mock_msg = Mock()
mock_msg.json.return_value = empty_batch_message
await objects_import.receive(mock_msg)
# Should still send the message
mock_publisher_instance.send.assert_called_once()
sent_object = mock_publisher_instance.send.call_args[0][1]
assert len(sent_object.values) == 0
class TestObjectsImportErrorHandling:
"""Test error handling in ObjectsImport."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message):
"""Test that receive() propagates publisher send errors."""
mock_publisher_instance = Mock()
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
mock_msg = Mock()
mock_msg.json.return_value = sample_objects_message
with pytest.raises(Exception, match="Publisher error"):
await objects_import.receive(mock_msg)
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
"""Test that receive() handles malformed JSON appropriately."""
mock_publisher_class.return_value = Mock()
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
mock_msg = Mock()
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
with pytest.raises(json.JSONDecodeError):
await objects_import.receive(mock_msg)

View file

@ -0,0 +1,326 @@
"""Unit tests for SocketEndpoint graceful shutdown functionality."""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from aiohttp import web, WSMsgType
from trustgraph.gateway.endpoint.socket import SocketEndpoint
from trustgraph.gateway.running import Running
@pytest.fixture
def mock_auth():
"""Mock authentication service."""
auth = MagicMock()
auth.permitted.return_value = True
return auth
@pytest.fixture
def mock_dispatcher_factory():
"""Mock dispatcher factory function."""
async def dispatcher_factory(ws, running, match_info):
dispatcher = AsyncMock()
dispatcher.run = AsyncMock()
dispatcher.receive = AsyncMock()
dispatcher.destroy = AsyncMock()
return dispatcher
return dispatcher_factory
@pytest.fixture
def socket_endpoint(mock_auth, mock_dispatcher_factory):
"""Create SocketEndpoint for testing."""
return SocketEndpoint(
endpoint_path="/test-socket",
auth=mock_auth,
dispatcher=mock_dispatcher_factory
)
@pytest.fixture
def mock_websocket():
"""Mock websocket response."""
ws = AsyncMock(spec=web.WebSocketResponse)
ws.prepare = AsyncMock()
ws.close = AsyncMock()
ws.closed = False
return ws
@pytest.fixture
def mock_request():
"""Mock HTTP request."""
request = MagicMock()
request.query = {"token": "test-token"}
request.match_info = {}
return request
@pytest.mark.asyncio
async def test_listener_graceful_shutdown_on_close():
"""Test listener handles websocket close gracefully."""
socket_endpoint = SocketEndpoint("/test", MagicMock(), AsyncMock())
# Mock websocket that closes after one message
ws = AsyncMock()
# Create async iterator that yields one message then closes
async def mock_iterator(self):
# Yield normal message
msg = MagicMock()
msg.type = WSMsgType.TEXT
yield msg
# Yield close message
close_msg = MagicMock()
close_msg.type = WSMsgType.CLOSE
yield close_msg
# Set the async iterator method
ws.__aiter__ = mock_iterator
dispatcher = AsyncMock()
running = Running()
with patch('asyncio.sleep') as mock_sleep:
await socket_endpoint.listener(ws, dispatcher, running)
# Should have processed one message
dispatcher.receive.assert_called_once()
# Should have initiated graceful shutdown
assert running.get() is False
# Should have slept for grace period
mock_sleep.assert_called_once_with(1.0)
@pytest.mark.asyncio
async def test_handle_normal_flow():
"""Test normal websocket handling flow."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
dispatcher_created = False
async def mock_dispatcher_factory(ws, running, match_info):
nonlocal dispatcher_created
dispatcher_created = True
dispatcher = AsyncMock()
dispatcher.destroy = AsyncMock()
return dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
request = MagicMock()
request.query = {"token": "valid-token"}
request.match_info = {}
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
mock_ws = AsyncMock()
mock_ws.prepare = AsyncMock()
mock_ws.close = AsyncMock()
mock_ws.closed = False
mock_ws_class.return_value = mock_ws
with patch('asyncio.TaskGroup') as mock_task_group:
# Mock task group context manager
mock_tg = AsyncMock()
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
mock_tg.__aexit__ = AsyncMock(return_value=None)
mock_tg.create_task = MagicMock(return_value=AsyncMock())
mock_task_group.return_value = mock_tg
result = await socket_endpoint.handle(request)
# Should have created dispatcher
assert dispatcher_created is True
# Should return websocket
assert result == mock_ws
@pytest.mark.asyncio
async def test_handle_exception_group_cleanup():
"""Test exception group triggers dispatcher cleanup."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_dispatcher = AsyncMock()
mock_dispatcher.destroy = AsyncMock()
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
request = MagicMock()
request.query = {"token": "valid-token"}
request.match_info = {}
# Mock TaskGroup to raise ExceptionGroup
class TestException(Exception):
pass
exception_group = ExceptionGroup("Test exceptions", [TestException("test")])
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
mock_ws = AsyncMock()
mock_ws.prepare = AsyncMock()
mock_ws.close = AsyncMock()
mock_ws.closed = False
mock_ws_class.return_value = mock_ws
with patch('asyncio.TaskGroup') as mock_task_group:
mock_tg = AsyncMock()
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
mock_tg.create_task = MagicMock(side_effect=TestException("test"))
mock_task_group.return_value = mock_tg
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
mock_wait_for.return_value = None
result = await socket_endpoint.handle(request)
# Should have attempted graceful cleanup
mock_wait_for.assert_called_once()
# Should have called destroy in finally block
assert mock_dispatcher.destroy.call_count >= 1
# Should have closed websocket
mock_ws.close.assert_called()
@pytest.mark.asyncio
async def test_handle_dispatcher_cleanup_timeout():
"""Test dispatcher cleanup with timeout."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
# Mock dispatcher that takes long to destroy
mock_dispatcher = AsyncMock()
mock_dispatcher.destroy = AsyncMock()
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
request = MagicMock()
request.query = {"token": "valid-token"}
request.match_info = {}
# Mock TaskGroup to raise exception
exception_group = ExceptionGroup("Test", [Exception("test")])
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
mock_ws = AsyncMock()
mock_ws.prepare = AsyncMock()
mock_ws.close = AsyncMock()
mock_ws.closed = False
mock_ws_class.return_value = mock_ws
with patch('asyncio.TaskGroup') as mock_task_group:
mock_tg = AsyncMock()
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
mock_tg.create_task = MagicMock(side_effect=Exception("test"))
mock_task_group.return_value = mock_tg
# Mock asyncio.wait_for to raise TimeoutError
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")
result = await socket_endpoint.handle(request)
# Should have attempted cleanup with timeout
mock_wait_for.assert_called_once()
# Check that timeout was passed correctly
assert mock_wait_for.call_args[1]['timeout'] == 5.0
# Should still call destroy in finally block
assert mock_dispatcher.destroy.call_count >= 1
@pytest.mark.asyncio
async def test_handle_unauthorized_request():
"""Test handling of unauthorized requests."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = False # Unauthorized
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
request = MagicMock()
request.query = {"token": "invalid-token"}
result = await socket_endpoint.handle(request)
# Should return HTTP 401
assert isinstance(result, web.HTTPUnauthorized)
# Should have checked permission
mock_auth.permitted.assert_called_once_with("invalid-token", "socket")
@pytest.mark.asyncio
async def test_handle_missing_token():
"""Test handling of requests with missing token."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = False
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
request = MagicMock()
request.query = {} # No token
result = await socket_endpoint.handle(request)
# Should return HTTP 401
assert isinstance(result, web.HTTPUnauthorized)
# Should have checked permission with empty token
mock_auth.permitted.assert_called_once_with("", "socket")
@pytest.mark.asyncio
async def test_handle_websocket_already_closed():
"""Test handling when websocket is already closed."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_dispatcher = AsyncMock()
mock_dispatcher.destroy = AsyncMock()
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
request = MagicMock()
request.query = {"token": "valid-token"}
request.match_info = {}
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
mock_ws = AsyncMock()
mock_ws.prepare = AsyncMock()
mock_ws.close = AsyncMock()
mock_ws.closed = True # Already closed
mock_ws_class.return_value = mock_ws
with patch('asyncio.TaskGroup') as mock_task_group:
mock_tg = AsyncMock()
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
mock_tg.__aexit__ = AsyncMock(return_value=None)
mock_tg.create_task = MagicMock(return_value=AsyncMock())
mock_task_group.return_value = mock_tg
result = await socket_endpoint.handle(request)
# Should still have called destroy
mock_dispatcher.destroy.assert_called()
# Should not attempt to close already closed websocket
mock_ws.close.assert_not_called() # Not called in finally since ws.closed = True