trustgraph/tests/unit/test_gateway/test_socket_graceful_shutdown.py
cybermaggedon 5e28d3cce0
refactor(iam): pluggable IAM regime via authenticate/authorise contract (#853)
The gateway no longer holds any policy state — capability sets, role
definitions, workspace scope rules.  Per the IAM contract it asks the
regime "may this identity perform this capability on this resource?"
per request.  That moves the OSS role-based regime entirely into
iam-svc, which can be replaced (SSO, ABAC, ReBAC) without changing
the gateway, the wire protocol, or backend services.

Contract:
- authenticate(credential) -> Identity (handle, workspace,
  principal_id, source).  No roles, claims, or policy state surface
  to the gateway.
- authorise(identity, capability, resource, parameters) -> (allow,
  ttl).  Cached per-decision (regime TTL clamped above; fail-closed
  on regime errors).
- authorise_many available as a fan-out variant.

Operation registry drives every authorisation decision:
- /api/v1/iam -> IamEndpoint, looks up bare op name (create-user,
  list-workspaces, ...).
- /api/v1/{kind} -> RegistryRoutedVariableEndpoint, <kind>:<op>
  (config:get, flow:list-blueprints, librarian:add-document, ...).
- /api/v1/flow/{flow}/service/{kind} -> flow-service:<kind>.
- /api/v1/flow/{flow}/{import,export}/{kind} ->
  flow-{import,export}:<kind>.
- WS Mux per-frame -> flow-service:<kind>; closes a gap where
  authenticated users could hit any service kind.
85 operations registered across the surface.

JWT carries identity only — sub + workspace.  The roles claim is gone;
the gateway never reads policy state from a credential.

The three coarse *_KIND_CAPABILITY maps are removed.  The registry is
the only source of truth for the capability + resource shape of an
operation.  Tests migrated to the new Identity shape and to
authorise()-mocked auth doubles.

Specs updated: docs/tech-specs/iam-contract.md (Identity surface,
caching, registry-naming conventions), iam.md (JWT shape, gateway
flow, role section reframed as OSS-regime detail), iam-protocol.md
(positioned as one implementation of the contract).
2026-04-28 16:19:41 +01:00

442 lines
No EOL
15 KiB
Python

"""Unit tests for SocketEndpoint graceful shutdown functionality.
These tests exercise SocketEndpoint in its handshake-auth
configuration (``in_band_auth=False``) — the mode used in production
for the flow import/export streaming endpoints. The mux socket at
``/api/v1/socket`` uses ``in_band_auth=True`` instead, where the
handshake always accepts and authentication runs on the first
WebSocket frame; that path is covered by the Mux tests.
Every endpoint constructor here passes an explicit capability — no
permissive default is relied upon.
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from aiohttp import web, WSMsgType
from trustgraph.gateway.endpoint.socket import SocketEndpoint
from trustgraph.gateway.running import Running
from trustgraph.gateway.auth import Identity
# Representative capability string handed to every SocketEndpoint built
# by the tests below via ``capability=`` — corresponds to the
# flow-import streaming endpoint pattern that uses this class.
TEST_CAP = "graph:write"
def _valid_identity():
    """Return the Identity used as the successful authentication result.

    The same fixed field values are produced on every call: an
    API-key-sourced principal named ``test-user`` in the ``default``
    workspace.
    """
    identity_fields = {
        "handle": "test-user",
        "workspace": "default",
        "principal_id": "test-user",
        "source": "api-key",
    }
    return Identity(**identity_fields)
@pytest.fixture
def mock_auth():
    """Provide a stand-in authenticator whose calls succeed by default.

    ``authenticate`` is an AsyncMock resolving to the identity from
    ``_valid_identity()``; ``authorise`` is an AsyncMock resolving to
    ``None`` without raising. A test that needs a failure path can
    replace either attribute on the returned object.
    """
    authenticator = MagicMock()
    authenticator.authorise = AsyncMock(return_value=None)
    authenticator.authenticate = AsyncMock(return_value=_valid_identity())
    return authenticator
@pytest.fixture
def mock_dispatcher_factory():
    """Provide an async factory that builds stub dispatchers.

    The returned coroutine function accepts ``(ws, running,
    match_info)``, ignores all three, and yields a fresh AsyncMock
    whose ``run``, ``receive`` and ``destroy`` attributes are each
    their own AsyncMock.
    """
    async def build_dispatcher(ws, running, match_info):
        stub = AsyncMock()
        for method_name in ("run", "receive", "destroy"):
            setattr(stub, method_name, AsyncMock())
        return stub

    return build_dispatcher
@pytest.fixture
def socket_endpoint(mock_auth, mock_dispatcher_factory):
    """Provide a SocketEndpoint wired to the mock auth and dispatcher.

    The endpoint is mounted at ``/test-socket`` and is given the
    module's explicit test capability.
    """
    endpoint_kwargs = dict(
        endpoint_path="/test-socket",
        auth=mock_auth,
        dispatcher=mock_dispatcher_factory,
        capability=TEST_CAP,
    )
    return SocketEndpoint(**endpoint_kwargs)
@pytest.fixture
def mock_websocket():
    """Provide an open (``closed`` is False) mock WebSocketResponse.

    The mock is spec'd against ``web.WebSocketResponse``; ``prepare``
    and ``close`` are replaced with plain AsyncMocks.
    """
    socket_stub = AsyncMock(spec=web.WebSocketResponse)
    for coroutine_attr in ("prepare", "close"):
        setattr(socket_stub, coroutine_attr, AsyncMock())
    socket_stub.closed = False
    return socket_stub
@pytest.fixture
def mock_request():
    """Provide a mock HTTP request carrying a ``token`` query parameter.

    ``match_info`` is an empty dict; ``query`` holds the single
    ``token`` entry.
    """
    req = MagicMock()
    req.match_info = {}
    req.query = {"token": "test-token"}
    return req
@pytest.mark.asyncio
async def test_listener_graceful_shutdown_on_close():
    """Listener stops cleanly when the websocket delivers a CLOSE frame.

    The fake websocket yields one TEXT frame followed by one CLOSE
    frame. The test then checks that exactly one frame reached the
    dispatcher, that the Running flag was cleared, and that the
    listener slept once for 1.0 seconds.
    """
    endpoint = SocketEndpoint(
        "/test", MagicMock(), AsyncMock(),
        capability=TEST_CAP,
    )

    async def two_frames(self):
        # One ordinary text frame, then the close frame that ends the
        # stream; each frame object is built just before it is yielded.
        for frame_type in (WSMsgType.TEXT, WSMsgType.CLOSE):
            frame = MagicMock()
            frame.type = frame_type
            yield frame

    fake_ws = AsyncMock()
    # Install the generator function as the async-iteration hook.
    fake_ws.__aiter__ = two_frames

    fake_dispatcher = AsyncMock()
    run_flag = Running()

    # asyncio.sleep is patched so the grace period costs no real time.
    with patch('asyncio.sleep') as mock_sleep:
        await endpoint.listener(fake_ws, fake_dispatcher, run_flag)

    # Only the TEXT frame is expected to be forwarded.
    fake_dispatcher.receive.assert_called_once()

    # Graceful shutdown must have been signalled.
    assert run_flag.get() is False

    # The grace-period sleep must have been requested exactly once.
    mock_sleep.assert_called_once_with(1.0)
@pytest.mark.asyncio
async def test_handle_normal_flow():
    """Valid token: the dispatcher factory runs and the websocket is returned.

    Authentication and authorisation both succeed, and the TaskGroup is
    replaced by a mock whose tasks are already finished. The test
    asserts that the dispatcher factory was invoked and that
    ``handle`` returns the (patched) websocket object.
    """
    auth = MagicMock()
    auth.authenticate = AsyncMock(return_value=_valid_identity())
    auth.authorise = AsyncMock(return_value=None)

    # Mutable flag flipped by the factory so the test can observe the call.
    state = {"created": False}

    async def recording_factory(ws, running, match_info):
        state["created"] = True
        stub = AsyncMock()
        stub.destroy = AsyncMock()
        return stub

    def finished_task(coro):
        # Close the coroutine so no "was never awaited" warning is
        # emitted, then hand back something that looks like a task
        # which has already completed without cancellation.
        coro.close()
        task_stub = AsyncMock()
        task_stub.done = MagicMock(return_value=True)
        task_stub.cancelled = MagicMock(return_value=False)
        return task_stub

    endpoint = SocketEndpoint(
        "/test", auth, recording_factory,
        capability=TEST_CAP,
    )

    http_request = MagicMock()
    http_request.query = {"token": "valid-token"}
    http_request.match_info = {}

    with (
        patch('aiohttp.web.WebSocketResponse') as ws_factory,
        patch('asyncio.TaskGroup') as task_group_factory,
    ):
        fake_ws = AsyncMock()
        fake_ws.prepare = AsyncMock()
        fake_ws.close = AsyncMock()
        fake_ws.closed = False
        ws_factory.return_value = fake_ws

        # Task group usable as an async context manager that exits
        # without raising.
        fake_tg = AsyncMock()
        fake_tg.__aenter__ = AsyncMock(return_value=fake_tg)
        fake_tg.__aexit__ = AsyncMock(return_value=None)
        fake_tg.create_task = MagicMock(side_effect=finished_task)
        task_group_factory.return_value = fake_tg

        outcome = await endpoint.handle(http_request)

    # The factory must have been called.
    assert state["created"] is True

    # handle() must hand back the websocket it prepared.
    assert outcome == fake_ws
@pytest.mark.asyncio
async def test_handle_exception_group_cleanup():
    """An ExceptionGroup from the task group still leads to cleanup.

    The mocked TaskGroup raises an ExceptionGroup on exit. The test
    asserts that ``wait_for`` was used once, that the dispatcher's
    ``destroy`` was called at least once, and that the websocket's
    ``close`` was called.
    """
    auth = MagicMock()
    auth.authenticate = AsyncMock(return_value=_valid_identity())
    auth.authorise = AsyncMock(return_value=None)

    dispatcher = AsyncMock()
    dispatcher.destroy = AsyncMock()

    async def fixed_factory(ws, running, match_info):
        return dispatcher

    class TestException(Exception):
        pass

    # Failure the task group reports when its context exits.
    group_failure = ExceptionGroup("Test exceptions", [TestException("test")])

    def finished_task(coro):
        # Close the coroutine to avoid a "was never awaited" warning and
        # return a task look-alike that is done and not cancelled.
        coro.close()
        task_stub = AsyncMock()
        task_stub.done = MagicMock(return_value=True)
        task_stub.cancelled = MagicMock(return_value=False)
        return task_stub

    async def swallow_and_return(coro, timeout=None):
        # Stand-in for wait_for: dispose of the coroutine, succeed.
        coro.close()
        return None

    endpoint = SocketEndpoint(
        "/test", auth, fixed_factory,
        capability=TEST_CAP,
    )

    http_request = MagicMock()
    http_request.query = {"token": "valid-token"}
    http_request.match_info = {}

    with (
        patch('aiohttp.web.WebSocketResponse') as ws_factory,
        patch('asyncio.TaskGroup') as task_group_factory,
        patch(
            'trustgraph.gateway.endpoint.socket.asyncio.wait_for',
            new_callable=AsyncMock,
        ) as mock_wait_for,
    ):
        fake_ws = AsyncMock()
        fake_ws.prepare = AsyncMock()
        fake_ws.close = AsyncMock()
        fake_ws.closed = False
        ws_factory.return_value = fake_ws

        fake_tg = AsyncMock()
        fake_tg.__aenter__ = AsyncMock(return_value=fake_tg)
        fake_tg.__aexit__ = AsyncMock(side_effect=group_failure)
        fake_tg.create_task = MagicMock(side_effect=finished_task)
        task_group_factory.return_value = fake_tg

        mock_wait_for.side_effect = swallow_and_return

        await endpoint.handle(http_request)

    # Graceful cleanup must have gone through wait_for exactly once.
    mock_wait_for.assert_called_once()

    # destroy() must have been invoked at least once.
    assert dispatcher.destroy.call_count > 0

    # The websocket must have been closed.
    fake_ws.close.assert_called()
@pytest.mark.asyncio
async def test_handle_dispatcher_cleanup_timeout():
    """Cleanup tolerates a timeout while destroying the dispatcher.

    The timeout is simulated by patching ``wait_for`` to raise
    ``asyncio.TimeoutError``; the dispatcher's ``destroy`` itself is an
    ordinary AsyncMock. The test asserts that ``wait_for`` was called
    once with ``timeout=5.0`` and that ``destroy`` was called at least
    once.
    """
    auth = MagicMock()
    auth.authenticate = AsyncMock(return_value=_valid_identity())
    auth.authorise = AsyncMock(return_value=None)

    dispatcher = AsyncMock()
    dispatcher.destroy = AsyncMock()

    async def fixed_factory(ws, running, match_info):
        return dispatcher

    # Failure raised when the mocked task group context exits; this is
    # what sends handle() down its cleanup path in this test.
    group_failure = ExceptionGroup("Test", [Exception("test")])

    def finished_task(coro):
        # Close the coroutine to avoid a "was never awaited" warning and
        # return a task look-alike that is done and not cancelled.
        coro.close()
        task_stub = AsyncMock()
        task_stub.done = MagicMock(return_value=True)
        task_stub.cancelled = MagicMock(return_value=False)
        return task_stub

    async def always_times_out(coro, timeout=None):
        # Stand-in for wait_for: dispose of the coroutine, then report
        # that the deadline was exceeded.
        coro.close()
        raise asyncio.TimeoutError("Cleanup timeout")

    endpoint = SocketEndpoint(
        "/test", auth, fixed_factory,
        capability=TEST_CAP,
    )

    http_request = MagicMock()
    http_request.query = {"token": "valid-token"}
    http_request.match_info = {}

    with (
        patch('aiohttp.web.WebSocketResponse') as ws_factory,
        patch('asyncio.TaskGroup') as task_group_factory,
        patch(
            'trustgraph.gateway.endpoint.socket.asyncio.wait_for',
            new_callable=AsyncMock,
        ) as mock_wait_for,
    ):
        fake_ws = AsyncMock()
        fake_ws.prepare = AsyncMock()
        fake_ws.close = AsyncMock()
        fake_ws.closed = False
        ws_factory.return_value = fake_ws

        fake_tg = AsyncMock()
        fake_tg.__aenter__ = AsyncMock(return_value=fake_tg)
        fake_tg.__aexit__ = AsyncMock(side_effect=group_failure)
        fake_tg.create_task = MagicMock(side_effect=finished_task)
        task_group_factory.return_value = fake_tg

        mock_wait_for.side_effect = always_times_out

        await endpoint.handle(http_request)

    # Cleanup must have been attempted under a timeout, exactly once.
    mock_wait_for.assert_called_once()

    # The timeout must have been supplied as a keyword argument of 5.0.
    assert mock_wait_for.call_args.kwargs['timeout'] == 5.0

    # destroy() must still have been invoked at least once.
    assert dispatcher.destroy.call_count > 0
@pytest.mark.asyncio
async def test_handle_unauthorized_request():
    """A token rejected by the authenticator yields HTTPUnauthorized.

    ``authenticate`` is mocked to raise ``web.HTTPUnauthorized``. The
    test asserts that ``handle`` returns an ``HTTPUnauthorized``
    instance and that ``authenticate`` was called once with an object
    whose ``Authorization`` header is ``Bearer`` plus the query-string
    token. Nothing here distinguishes 'bad token' from 'expired',
    'revoked', etc.
    """
    rejection = web.HTTPUnauthorized(
        text='{"error":"auth failure"}',
        content_type="application/json",
    )
    auth = MagicMock()
    auth.authenticate = AsyncMock(side_effect=rejection)

    endpoint = SocketEndpoint(
        "/test", auth, AsyncMock(),
        capability=TEST_CAP,
    )

    http_request = MagicMock()
    http_request.query = {"token": "invalid-token"}

    outcome = await endpoint.handle(http_request)

    assert isinstance(outcome, web.HTTPUnauthorized)

    # The authenticator must have been consulted exactly once, with a
    # request-like object in which the query-string token has been
    # turned into an Authorization: Bearer header.
    auth.authenticate.assert_called_once()
    forwarded_request = auth.authenticate.call_args.args[0]
    assert forwarded_request.headers["Authorization"] == "Bearer invalid-token"
@pytest.mark.asyncio
async def test_handle_missing_token():
    """No ``token`` query parameter: 401 without calling the authenticator.

    ``authenticate`` is rigged to raise AssertionError if it is ever
    awaited. The test asserts that ``handle`` returns an
    ``HTTPUnauthorized`` instance and that ``authenticate`` was not
    called.
    """
    tripwire = AssertionError(
        "authenticate must not be invoked when no token is present"
    )
    auth = MagicMock()
    auth.authenticate = AsyncMock(side_effect=tripwire)

    endpoint = SocketEndpoint(
        "/test", auth, AsyncMock(),
        capability=TEST_CAP,
    )

    http_request = MagicMock()
    # Empty query mapping: the request carries no token at all.
    http_request.query = {}

    outcome = await endpoint.handle(http_request)

    assert isinstance(outcome, web.HTTPUnauthorized)
    auth.authenticate.assert_not_called()
@pytest.mark.asyncio
async def test_handle_websocket_already_closed():
    """A websocket that already reports closed is not closed again.

    The patched websocket has ``closed`` set to True and the mocked
    task group exits normally. The test asserts that the dispatcher's
    ``destroy`` was called and that the websocket's ``close`` was
    never called.
    """
    auth = MagicMock()
    auth.authenticate = AsyncMock(return_value=_valid_identity())
    auth.authorise = AsyncMock(return_value=None)

    dispatcher = AsyncMock()
    dispatcher.destroy = AsyncMock()

    async def fixed_factory(ws, running, match_info):
        return dispatcher

    def finished_task(coro):
        # Close the coroutine to avoid a "was never awaited" warning and
        # return a task look-alike that is done and not cancelled.
        coro.close()
        task_stub = AsyncMock()
        task_stub.done = MagicMock(return_value=True)
        task_stub.cancelled = MagicMock(return_value=False)
        return task_stub

    endpoint = SocketEndpoint(
        "/test", auth, fixed_factory,
        capability=TEST_CAP,
    )

    http_request = MagicMock()
    http_request.query = {"token": "valid-token"}
    http_request.match_info = {}

    with (
        patch('aiohttp.web.WebSocketResponse') as ws_factory,
        patch('asyncio.TaskGroup') as task_group_factory,
    ):
        fake_ws = AsyncMock()
        fake_ws.prepare = AsyncMock()
        fake_ws.close = AsyncMock()
        # The socket claims to be closed before handle() ever runs.
        fake_ws.closed = True
        ws_factory.return_value = fake_ws

        fake_tg = AsyncMock()
        fake_tg.__aenter__ = AsyncMock(return_value=fake_tg)
        fake_tg.__aexit__ = AsyncMock(return_value=None)
        fake_tg.create_task = MagicMock(side_effect=finished_task)
        task_group_factory.return_value = fake_tg

        await endpoint.handle(http_request)

    # The dispatcher must still be torn down.
    dispatcher.destroy.assert_called()

    # With closed already True, the endpoint is expected to skip close().
    fake_ws.close.assert_not_called()