trustgraph/tests/unit/test_gateway/test_socket_graceful_shutdown.py
cybermaggedon 0a828379be
feat: global usernames and rename workspace to default_workspace (#1001)
Users are global entities, not scoped to workspaces. This change:

Track A — Global usernames:
- Change iam_users_by_username to PRIMARY KEY (username), removing
  workspace from the lookup key
- Login looks up username globally, no workspace required
- Username uniqueness is enforced globally, not per-workspace
- Login -w now overrides the JWT workspace (session workspace)
  rather than selecting which user registry to search

Track B — Rename workspace to default_workspace:
- UserRecord.workspace → UserRecord.default_workspace
- Identity.workspace → Identity.default_workspace
- JWT claim "workspace" → "default_workspace"
- IamResponse.resolved_workspace → resolved_default_workspace
- WebSocket auth-ok frame field → default_workspace
- Socket clients read default_workspace from auth-ok
- _user_record_to_dict wire key → default_workspace
- CLI help text and output updated throughout
- Test files updated for renamed fields
2026-06-25 16:34:31 +01:00

442 lines
No EOL
15 KiB
Python

"""Unit tests for SocketEndpoint graceful shutdown functionality.
These tests exercise SocketEndpoint in its handshake-auth
configuration (``in_band_auth=False``) — the mode used in production
for the flow import/export streaming endpoints. The mux socket at
``/api/v1/socket`` uses ``in_band_auth=True`` instead, where the
handshake always accepts and authentication runs on the first
WebSocket frame; that path is covered by the Mux tests.
Every endpoint constructor here passes an explicit capability — no
permissive default is relied upon.
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from aiohttp import web, WSMsgType
from trustgraph.gateway.endpoint.socket import SocketEndpoint
from trustgraph.gateway.running import Running
from trustgraph.gateway.auth import Identity
# Capability string handed to every SocketEndpoint built in this module.
# It is a representative value — it corresponds to the flow-import
# streaming endpoint pattern that uses this class.
TEST_CAP = "graph:write"
def _valid_identity():
    """Build the Identity returned by the mocked authenticator on success."""
    fields = {
        "handle": "test-user",
        "default_workspace": "default",
        "principal_id": "test-user",
        "source": "api-key",
    }
    return Identity(**fields)
@pytest.fixture
def mock_auth():
    """Stand-in for the IAM-backed authenticator.

    Happy path by default: ``authenticate`` resolves to a valid identity
    and ``authorise`` resolves to ``None`` without raising. Tests that
    exercise failure paths replace the relevant attribute locally.
    """
    authenticator = MagicMock()
    authenticator.configure_mock(
        authenticate=AsyncMock(return_value=_valid_identity()),
        authorise=AsyncMock(return_value=None),
    )
    return authenticator
@pytest.fixture
def mock_dispatcher_factory():
    """Provide an async factory that returns a fresh mock dispatcher per call."""

    async def dispatcher_factory(ws, running, match_info):
        stub = AsyncMock()
        # Give each lifecycle hook its own AsyncMock so calls can be
        # asserted on individually.
        for hook in ("run", "receive", "destroy"):
            setattr(stub, hook, AsyncMock())
        return stub

    return dispatcher_factory
@pytest.fixture
def socket_endpoint(mock_auth, mock_dispatcher_factory):
    """SocketEndpoint wired to the mocked authenticator and dispatcher factory."""
    settings = {
        "endpoint_path": "/test-socket",
        "auth": mock_auth,
        "dispatcher": mock_dispatcher_factory,
        "capability": TEST_CAP,
    }
    return SocketEndpoint(**settings)
@pytest.fixture
def mock_websocket():
    """Mock websocket response that reports itself as still open."""
    socket = AsyncMock(spec=web.WebSocketResponse)
    socket.configure_mock(
        prepare=AsyncMock(),
        close=AsyncMock(),
        closed=False,
    )
    return socket
@pytest.fixture
def mock_request():
    """Mock HTTP request carrying a ``token`` query parameter and empty match_info."""
    stub = MagicMock()
    stub.configure_mock(
        query={"token": "test-token"},
        match_info={},
    )
    return stub
@pytest.mark.asyncio
async def test_listener_graceful_shutdown_on_close():
    """Listener shuts down gracefully once the websocket yields a CLOSE frame.

    The fake websocket yields one TEXT frame followed by one CLOSE frame.
    Afterwards the dispatcher must have received exactly one message, the
    running flag must be cleared, and the 1.0s grace sleep must have run.
    """
    endpoint = SocketEndpoint(
        "/test", MagicMock(), AsyncMock(),
        capability=TEST_CAP,
    )

    def frame(kind):
        # Minimal stand-in for a websocket message: only ``type`` is set.
        message = MagicMock()
        message.type = kind
        return message

    async def two_frames(self):
        # One ordinary message, then the close notification.
        yield frame(WSMsgType.TEXT)
        yield frame(WSMsgType.CLOSE)

    ws = AsyncMock()
    ws.__aiter__ = two_frames

    dispatcher = AsyncMock()
    running = Running()

    # Patch the sleep so the grace period does not slow the test down.
    with patch('asyncio.sleep') as mock_sleep:
        await endpoint.listener(ws, dispatcher, running)

    # Exactly one message reached the dispatcher.
    dispatcher.receive.assert_called_once()
    # Graceful shutdown was initiated.
    assert running.get() is False
    # The grace period was honoured.
    mock_sleep.assert_called_once_with(1.0)
@pytest.mark.asyncio
async def test_handle_normal_flow():
    """Valid bearer → handshake accepted, dispatcher created."""
    auth = MagicMock()
    auth.configure_mock(
        authenticate=AsyncMock(return_value=_valid_identity()),
        authorise=AsyncMock(return_value=None),
    )

    # Records each invocation of the dispatcher factory.
    factory_calls = []

    async def dispatcher_factory(ws, running, match_info):
        factory_calls.append(True)
        stub = AsyncMock()
        stub.destroy = AsyncMock()
        return stub

    endpoint = SocketEndpoint(
        "/test", auth, dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.configure_mock(query={"token": "valid-token"}, match_info={})

    def finished_task(coro):
        # Close the coroutine to avoid a "was never awaited" warning,
        # then return something that looks like a completed asyncio.Task.
        coro.close()
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    fake_ws = AsyncMock()
    fake_ws.prepare = AsyncMock()
    fake_ws.close = AsyncMock()
    fake_ws.closed = False

    # Task group whose async context manager enters and exits cleanly.
    fake_tg = AsyncMock()
    fake_tg.__aenter__ = AsyncMock(return_value=fake_tg)
    fake_tg.__aexit__ = AsyncMock(return_value=None)
    fake_tg.create_task = MagicMock(side_effect=finished_task)

    with (
        patch('aiohttp.web.WebSocketResponse') as ws_class,
        patch('asyncio.TaskGroup') as task_group_class,
    ):
        ws_class.return_value = fake_ws
        task_group_class.return_value = fake_tg
        result = await endpoint.handle(request)

    # The dispatcher factory was invoked.
    assert factory_calls
    # The handler hands back the websocket it prepared.
    assert result == fake_ws
@pytest.mark.asyncio
async def test_handle_exception_group_cleanup():
    """An ExceptionGroup raised by the task group still triggers cleanup.

    The patched ``asyncio.TaskGroup`` raises an ``ExceptionGroup`` from
    ``__aexit__``. The test then verifies that ``handle`` goes through
    ``asyncio.wait_for`` exactly once, calls ``dispatcher.destroy`` and
    closes the websocket. ``handle`` is awaited without expecting an
    exception, so the group must not escape the handler.
    """
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
    mock_auth.authorise = AsyncMock(return_value=None)

    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, mock_dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    # Deliberately not named ``Test*`` so it cannot be mistaken for a
    # pytest test class.
    class SimulatedFailure(Exception):
        """Marker exception carried inside the ExceptionGroup."""

    exception_group = ExceptionGroup(
        "Test exceptions", [SimulatedFailure("test")]
    )

    def create_task_mock(coro):
        # Close the coroutine to avoid a "was never awaited" warning,
        # then return something that looks like a finished asyncio.Task.
        coro.close()
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    async def wait_for_side_effect(coro, timeout=None):
        # Close the coroutine handed to wait_for; cleanup "succeeds".
        coro.close()
        return None

    mock_ws = AsyncMock()
    mock_ws.prepare = AsyncMock()
    mock_ws.close = AsyncMock()
    mock_ws.closed = False

    # Task group whose exit raises the ExceptionGroup.
    mock_tg = AsyncMock()
    mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
    mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
    mock_tg.create_task = MagicMock(side_effect=create_task_mock)

    with (
        patch('aiohttp.web.WebSocketResponse') as mock_ws_class,
        patch('asyncio.TaskGroup') as mock_task_group,
        patch(
            'trustgraph.gateway.endpoint.socket.asyncio.wait_for',
            new_callable=AsyncMock,
        ) as mock_wait_for,
    ):
        mock_ws_class.return_value = mock_ws
        mock_task_group.return_value = mock_tg
        mock_wait_for.side_effect = wait_for_side_effect

        # The return value is irrelevant here; only side effects are checked.
        await socket_endpoint.handle(request)

    # Should have attempted graceful cleanup
    mock_wait_for.assert_called_once()
    # Should have called destroy at least once
    mock_dispatcher.destroy.assert_called()
    # Should have closed websocket
    mock_ws.close.assert_called()
@pytest.mark.asyncio
async def test_handle_dispatcher_cleanup_timeout():
    """Dispatcher cleanup is bounded by a 5 second timeout.

    The task group is made to fail with an ``ExceptionGroup`` and the
    patched ``asyncio.wait_for`` raises ``asyncio.TimeoutError``. The
    test checks that ``wait_for`` was called once with ``timeout=5.0``
    and that ``dispatcher.destroy`` was still called. ``handle`` is
    awaited without expecting an exception, so the timeout must not
    escape the handler.
    """
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
    mock_auth.authorise = AsyncMock(return_value=None)

    # The dispatcher itself is an ordinary mock; the slow cleanup is
    # simulated below by making the patched wait_for raise TimeoutError.
    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, mock_dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    # Raised from the task group's __aexit__ to force the cleanup path.
    exception_group = ExceptionGroup("Test", [Exception("test")])

    def create_task_mock(coro):
        # Close the coroutine to avoid a "was never awaited" warning,
        # then return something that looks like a finished asyncio.Task.
        coro.close()
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    async def wait_for_timeout(coro, timeout=None):
        # Close the coroutine before simulating the timeout.
        coro.close()
        raise asyncio.TimeoutError("Cleanup timeout")

    mock_ws = AsyncMock()
    mock_ws.prepare = AsyncMock()
    mock_ws.close = AsyncMock()
    mock_ws.closed = False

    mock_tg = AsyncMock()
    mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
    mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
    mock_tg.create_task = MagicMock(side_effect=create_task_mock)

    with (
        patch('aiohttp.web.WebSocketResponse') as mock_ws_class,
        patch('asyncio.TaskGroup') as mock_task_group,
        patch(
            'trustgraph.gateway.endpoint.socket.asyncio.wait_for',
            new_callable=AsyncMock,
        ) as mock_wait_for,
    ):
        mock_ws_class.return_value = mock_ws
        mock_task_group.return_value = mock_tg
        mock_wait_for.side_effect = wait_for_timeout

        # The return value is irrelevant here; only side effects are checked.
        await socket_endpoint.handle(request)

    # Should have attempted cleanup with timeout
    mock_wait_for.assert_called_once()
    # Check that timeout was passed correctly (as a keyword argument)
    assert mock_wait_for.call_args.kwargs['timeout'] == 5.0
    # Should still have called destroy despite the timeout
    mock_dispatcher.destroy.assert_called()
@pytest.mark.asyncio
async def test_handle_unauthorized_request():
    """A bearer rejected by the authenticator yields an HTTPUnauthorized.

    The mocked ``authenticate`` raises ``web.HTTPUnauthorized`` and the
    handler's result must be an instance of it. The endpoint
    intentionally does NOT distinguish 'bad token', 'expired',
    'revoked', etc. — that's the IAM error-masking policy.
    """
    rejection = web.HTTPUnauthorized(
        text='{"error":"auth failure"}',
        content_type="application/json",
    )
    auth = MagicMock()
    auth.authenticate = AsyncMock(side_effect=rejection)

    endpoint = SocketEndpoint(
        "/test", auth, AsyncMock(),
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "invalid-token"}

    outcome = await endpoint.handle(request)
    assert isinstance(outcome, web.HTTPUnauthorized)

    # authenticate must have been invoked with a synthetic request
    # carrying Bearer <the-token>: the endpoint wraps the query-string
    # token into an Authorization header for a uniform auth path — the
    # IAM layer does not look at query strings directly.
    auth.authenticate.assert_called_once()
    forwarded = auth.authenticate.call_args.args[0]
    assert forwarded.headers["Authorization"] == "Bearer invalid-token"
@pytest.mark.asyncio
async def test_handle_missing_token():
    """No ``token`` query param → HTTPUnauthorized without touching
    the authenticator (cheap short-circuit)."""
    # Tripwire: if the endpoint does reach the authenticator, the call
    # itself blows up with a descriptive message.
    tripwire = AssertionError(
        "authenticate must not be invoked when no token is present"
    )
    auth = MagicMock()
    auth.authenticate = AsyncMock(side_effect=tripwire)

    endpoint = SocketEndpoint(
        "/test", auth, AsyncMock(),
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {}  # No token

    outcome = await endpoint.handle(request)

    assert isinstance(outcome, web.HTTPUnauthorized)
    auth.authenticate.assert_not_called()
@pytest.mark.asyncio
async def test_handle_websocket_already_closed():
    """A websocket that is already closed is not closed a second time.

    With ``ws.closed`` set to ``True`` and a task group that exits
    cleanly, the handler must still call ``dispatcher.destroy`` but must
    not call ``ws.close``.
    """
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
    mock_auth.authorise = AsyncMock(return_value=None)

    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, mock_dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    def create_task_mock(coro):
        # Close the coroutine to avoid a "was never awaited" warning,
        # then return something that looks like a finished asyncio.Task.
        coro.close()
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    mock_ws = AsyncMock()
    mock_ws.prepare = AsyncMock()
    mock_ws.close = AsyncMock()
    mock_ws.closed = True  # Already closed

    # Task group whose async context manager enters and exits cleanly.
    mock_tg = AsyncMock()
    mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
    mock_tg.__aexit__ = AsyncMock(return_value=None)
    mock_tg.create_task = MagicMock(side_effect=create_task_mock)

    with (
        patch('aiohttp.web.WebSocketResponse') as mock_ws_class,
        patch('asyncio.TaskGroup') as mock_task_group,
    ):
        mock_ws_class.return_value = mock_ws
        mock_task_group.return_value = mock_tg

        # The return value is irrelevant here; only side effects are checked.
        await socket_endpoint.handle(request)

    # Should still have called destroy
    mock_dispatcher.destroy.assert_called()
    # Should not attempt to close an already closed websocket
    mock_ws.close.assert_not_called()