"""Unit tests for SocketEndpoint graceful shutdown functionality.

These tests exercise SocketEndpoint in its handshake-auth configuration
(``in_band_auth=False``) — the mode used in production for the flow
import/export streaming endpoints.  The mux socket at ``/api/v1/socket``
uses ``in_band_auth=True`` instead, where the handshake always accepts and
authentication runs on the first WebSocket frame; that path is covered by
the Mux tests.

Every endpoint constructor here passes an explicit capability — no
permissive default is relied upon.
"""

import asyncio
from contextlib import contextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from aiohttp import web, WSMsgType

from trustgraph.gateway.auth import Identity
from trustgraph.gateway.endpoint.socket import SocketEndpoint
from trustgraph.gateway.running import Running

# Representative capability used across these tests — corresponds to
# the flow-import streaming endpoint pattern that uses this class.
TEST_CAP = "graph:write"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _valid_identity():
    """Return the Identity the mocked IAM layer hands back on success."""
    return Identity(
        handle="test-user",
        workspace="default",
        principal_id="test-user",
        source="api-key",
    )


def _make_auth():
    """Build a mock authenticator that accepts every request.

    ``authenticate`` resolves to :func:`_valid_identity` and ``authorise``
    resolves to ``None`` (no objection).  Tests that need a failure path
    replace the relevant attribute on the returned mock.
    """
    auth = MagicMock()
    auth.authenticate = AsyncMock(return_value=_valid_identity())
    auth.authorise = AsyncMock(return_value=None)
    return auth


def _make_request(query):
    """Build a mock aiohttp request whose query string is *query*."""
    request = MagicMock()
    request.query = query
    request.match_info = {}
    return request


def _fake_task(coro):
    """Stand-in for ``TaskGroup.create_task`` that returns a finished task.

    The coroutine is closed immediately so Python does not warn that it
    "was never awaited".
    """
    coro.close()
    task = AsyncMock()
    task.done = MagicMock(return_value=True)
    task.cancelled = MagicMock(return_value=False)
    return task


@contextmanager
def _patched_handshake(*, ws_closed=False, task_group_error=None):
    """Patch ``web.WebSocketResponse`` and ``asyncio.TaskGroup`` for ``handle``.

    Yields the mock websocket that ``web.WebSocketResponse()`` returns.

    Args:
        ws_closed: value reported by the mock websocket's ``closed`` flag.
        task_group_error: if given, exiting the task group raises this
            exception (e.g. an ``ExceptionGroup``); otherwise the task
            group exits cleanly.
    """
    with (
        patch('aiohttp.web.WebSocketResponse') as mock_ws_class,
        patch('asyncio.TaskGroup') as mock_task_group,
    ):
        mock_ws = AsyncMock()
        mock_ws.prepare = AsyncMock()
        mock_ws.close = AsyncMock()
        mock_ws.closed = ws_closed
        mock_ws_class.return_value = mock_ws

        # Task group context manager: __aenter__ returns itself so
        # ``async with TaskGroup() as tg`` binds the same mock.
        mock_tg = AsyncMock()
        mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
        if task_group_error is None:
            # Must be an explicit None: a truthy __aexit__ result would
            # suppress any exception raised inside the ``async with``.
            mock_tg.__aexit__ = AsyncMock(return_value=None)
        else:
            mock_tg.__aexit__ = AsyncMock(side_effect=task_group_error)
        mock_tg.create_task = MagicMock(side_effect=_fake_task)
        mock_task_group.return_value = mock_tg

        yield mock_ws


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture
def mock_auth():
    """Mock IAM-backed authenticator.

    Successful by default — ``authenticate`` returns a valid identity and
    ``authorise`` allows everything.  Tests that need the failure paths
    override the relevant attribute locally.
    """
    return _make_auth()


@pytest.fixture
def mock_dispatcher_factory():
    """Mock dispatcher factory function."""
    async def dispatcher_factory(ws, running, match_info):
        dispatcher = AsyncMock()
        dispatcher.run = AsyncMock()
        dispatcher.receive = AsyncMock()
        dispatcher.destroy = AsyncMock()
        return dispatcher

    return dispatcher_factory


@pytest.fixture
def socket_endpoint(mock_auth, mock_dispatcher_factory):
    """Create SocketEndpoint for testing."""
    return SocketEndpoint(
        endpoint_path="/test-socket",
        auth=mock_auth,
        dispatcher=mock_dispatcher_factory,
        capability=TEST_CAP,
    )


@pytest.fixture
def mock_websocket():
    """Mock websocket response."""
    ws = AsyncMock(spec=web.WebSocketResponse)
    ws.prepare = AsyncMock()
    ws.close = AsyncMock()
    ws.closed = False
    return ws


@pytest.fixture
def mock_request():
    """Mock HTTP request."""
    return _make_request({"token": "test-token"})


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

@pytest.mark.asyncio
async def test_listener_graceful_shutdown_on_close():
    """Test listener handles websocket close gracefully."""
    socket_endpoint = SocketEndpoint(
        "/test", MagicMock(), AsyncMock(), capability=TEST_CAP,
    )

    # Mock websocket whose async iteration yields one TEXT frame, then CLOSE.
    ws = AsyncMock()

    async def mock_iterator(self):
        msg = MagicMock()
        msg.type = WSMsgType.TEXT
        yield msg

        close_msg = MagicMock()
        close_msg.type = WSMsgType.CLOSE
        yield close_msg

    # Magic methods assigned on a mock are called with the mock as ``self``.
    ws.__aiter__ = mock_iterator

    dispatcher = AsyncMock()
    running = Running()

    with patch('asyncio.sleep') as mock_sleep:
        await socket_endpoint.listener(ws, dispatcher, running)

    # Should have processed one message
    dispatcher.receive.assert_called_once()
    # Should have initiated graceful shutdown
    assert running.get() is False
    # Should have slept for grace period
    mock_sleep.assert_called_once_with(1.0)


@pytest.mark.asyncio
async def test_handle_normal_flow():
    """Valid bearer → handshake accepted, dispatcher created."""
    dispatcher_created = False

    async def mock_dispatcher_factory(ws, running, match_info):
        nonlocal dispatcher_created
        dispatcher_created = True
        dispatcher = AsyncMock()
        dispatcher.destroy = AsyncMock()
        return dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", _make_auth(), mock_dispatcher_factory, capability=TEST_CAP,
    )
    request = _make_request({"token": "valid-token"})

    with _patched_handshake() as mock_ws:
        result = await socket_endpoint.handle(request)

    # Should have created dispatcher
    assert dispatcher_created is True
    # Should return websocket
    assert result == mock_ws


@pytest.mark.asyncio
async def test_handle_exception_group_cleanup():
    """An ExceptionGroup from the task group triggers dispatcher cleanup."""
    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", _make_auth(), mock_dispatcher_factory, capability=TEST_CAP,
    )
    request = _make_request({"token": "valid-token"})

    class TestException(Exception):
        pass

    exception_group = ExceptionGroup("Test exceptions", [TestException("test")])

    async def wait_for_side_effect(coro, timeout=None):
        # Close the coroutine passed to wait_for so it is not left un-awaited.
        coro.close()
        return None

    with _patched_handshake(task_group_error=exception_group) as mock_ws:
        with patch(
            'trustgraph.gateway.endpoint.socket.asyncio.wait_for',
            new_callable=AsyncMock,
        ) as mock_wait_for:
            mock_wait_for.side_effect = wait_for_side_effect

            await socket_endpoint.handle(request)

    # Should have attempted graceful cleanup
    mock_wait_for.assert_called_once()
    # Should have called destroy in finally block
    assert mock_dispatcher.destroy.call_count >= 1
    # Should have closed websocket
    mock_ws.close.assert_called()


@pytest.mark.asyncio
async def test_handle_dispatcher_cleanup_timeout():
    """Dispatcher cleanup is bounded by a 5 s timeout and still destroys."""
    # Plain dispatcher; the cleanup timeout itself is simulated by the
    # patched ``asyncio.wait_for`` below rather than by a slow destroy().
    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", _make_auth(), mock_dispatcher_factory, capability=TEST_CAP,
    )
    request = _make_request({"token": "valid-token"})

    # Task group exit raises, forcing the cleanup path.
    exception_group = ExceptionGroup("Test", [Exception("test")])

    async def wait_for_timeout(coro, timeout=None):
        # Close the coroutine before raising so it is not left un-awaited.
        coro.close()
        raise asyncio.TimeoutError("Cleanup timeout")

    with _patched_handshake(task_group_error=exception_group):
        with patch(
            'trustgraph.gateway.endpoint.socket.asyncio.wait_for',
            new_callable=AsyncMock,
        ) as mock_wait_for:
            mock_wait_for.side_effect = wait_for_timeout

            await socket_endpoint.handle(request)

    # Should have attempted cleanup with timeout
    mock_wait_for.assert_called_once()
    # Check that timeout was passed correctly
    assert mock_wait_for.call_args[1]['timeout'] == 5.0
    # Should still call destroy in finally block
    assert mock_dispatcher.destroy.call_count >= 1


@pytest.mark.asyncio
async def test_handle_unauthorized_request():
    """A bearer that the IAM layer rejects causes the handshake to fail
    with 401.  IamAuth surfaces an HTTPUnauthorized; the endpoint returns
    it as the handshake response.  Note that the endpoint intentionally
    does NOT distinguish 'bad token', 'expired', 'revoked', etc. — that's
    the IAM error-masking policy."""
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(side_effect=web.HTTPUnauthorized(
        text='{"error":"auth failure"}',
        content_type="application/json",
    ))

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, AsyncMock(), capability=TEST_CAP,
    )
    request = _make_request({"token": "invalid-token"})

    result = await socket_endpoint.handle(request)

    assert isinstance(result, web.HTTPUnauthorized)

    # authenticate must have been invoked with a synthetic request
    # carrying ``Authorization: Bearer <token>``.  The endpoint wraps the
    # query-string token into an Authorization header for a uniform auth
    # path — the IAM layer does not look at query strings directly.
    mock_auth.authenticate.assert_called_once()
    passed_req = mock_auth.authenticate.call_args.args[0]
    assert passed_req.headers["Authorization"] == "Bearer invalid-token"


@pytest.mark.asyncio
async def test_handle_missing_token():
    """Request with no ``token`` query param → 401 before any IAM call
    is made (cheap short-circuit)."""
    # Bare MagicMock on purpose: any unexpected auth call fails the test.
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(
        side_effect=AssertionError(
            "authenticate must not be invoked when no token is present"
        ),
    )

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, AsyncMock(), capability=TEST_CAP,
    )
    request = _make_request({})  # No token

    result = await socket_endpoint.handle(request)

    assert isinstance(result, web.HTTPUnauthorized)
    mock_auth.authenticate.assert_not_called()


@pytest.mark.asyncio
async def test_handle_websocket_already_closed():
    """Test handling when websocket is already closed."""
    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", _make_auth(), mock_dispatcher_factory, capability=TEST_CAP,
    )
    request = _make_request({"token": "valid-token"})

    with _patched_handshake(ws_closed=True) as mock_ws:
        await socket_endpoint.handle(request)

    # Should still have called destroy
    mock_dispatcher.destroy.assert_called()
    # Should not attempt to close an already-closed websocket in finally
    mock_ws.close.assert_not_called()