mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 18:36:22 +02:00
The gateway no longer holds any policy state — capability sets, role
definitions, workspace scope rules. Per the IAM contract it asks the
regime "may this identity perform this capability on this resource?"
per request. That moves the OSS role-based regime entirely into
iam-svc, which can be replaced (SSO, ABAC, ReBAC) without changing
the gateway, the wire protocol, or backend services.
Contract:
- authenticate(credential) -> Identity (handle, workspace,
principal_id, source). No roles, claims, or policy state surface
to the gateway.
- authorise(identity, capability, resource, parameters) -> (allow,
ttl). Each decision is cached individually (the regime-supplied TTL
is capped at an upper bound; errors from the regime fail closed).
- authorise_many available as a fan-out variant.
Operation registry drives every authorisation decision:
- /api/v1/iam -> IamEndpoint, looks up bare op name (create-user,
list-workspaces, ...).
- /api/v1/{kind} -> RegistryRoutedVariableEndpoint, <kind>:<op>
(config:get, flow:list-blueprints, librarian:add-document, ...).
- /api/v1/flow/{flow}/service/{kind} -> flow-service:<kind>.
- /api/v1/flow/{flow}/{import,export}/{kind} ->
flow-{import,export}:<kind>.
- WS Mux per-frame -> flow-service:<kind>; closes a gap where
authenticated users could hit any service kind.
85 operations are registered across the API surface.
JWT carries identity only — sub + workspace. The roles claim is gone;
the gateway never reads policy state from a credential.
The three coarse *_KIND_CAPABILITY maps are removed. The registry is
the only source of truth for the capability + resource shape of an
operation. Tests migrated to the new Identity shape and to
authorise()-mocked auth doubles.
Specs updated: docs/tech-specs/iam-contract.md (Identity surface,
caching, registry-naming conventions), iam.md (JWT shape, gateway
flow, role section reframed as OSS-regime detail), iam-protocol.md
(positioned as one implementation of the contract).
442 lines
No EOL
15 KiB
Python
442 lines
No EOL
15 KiB
Python
"""Unit tests for SocketEndpoint graceful shutdown functionality.

These tests exercise SocketEndpoint in its handshake-auth
configuration (``in_band_auth=False``) — the mode used in production
for the flow import/export streaming endpoints. The mux socket at
``/api/v1/socket`` uses ``in_band_auth=True`` instead, where the
handshake always accepts and authentication runs on the first
WebSocket frame; that path is covered by the Mux tests.

In handshake-auth mode the bearer token arrives as the ``token``
query parameter and is handed to the authenticator as an
``Authorization: Bearer <token>`` header. A missing or rejected token
makes ``handle()`` return an ``HTTPUnauthorized`` response.

Every endpoint constructor here passes an explicit capability — no
permissive default is relied upon.
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from aiohttp import web, WSMsgType
|
|
from trustgraph.gateway.endpoint.socket import SocketEndpoint
|
|
from trustgraph.gateway.running import Running
|
|
from trustgraph.gateway.auth import Identity
|
|
|
|
|
|
# Capability passed explicitly to every SocketEndpoint constructed in
# this module. It is chosen to resemble the flow-import streaming
# endpoints that use this class in production.
TEST_CAP: str = "graph:write"
|
|
|
|
|
|
def _valid_identity():
    """Build the Identity that the happy-path authenticator doubles return.

    The handle and principal id are the same synthetic user; the
    credential source is reported as an API key.
    """
    user = "test-user"
    return Identity(
        handle=user,
        workspace="default",
        principal_id=user,
        source="api-key",
    )
|
|
|
|
|
|
@pytest.fixture
def mock_auth():
    """Happy-path stand-in for the IAM-backed authenticator.

    ``authenticate`` resolves to a valid Identity and ``authorise``
    resolves to ``None`` (allow). Tests that exercise failure paths
    replace the relevant attribute themselves.
    """
    return MagicMock(
        authenticate=AsyncMock(return_value=_valid_identity()),
        authorise=AsyncMock(return_value=None),
    )
|
|
|
|
|
|
@pytest.fixture
def mock_dispatcher_factory():
    """Provide an async factory that builds a fresh mock dispatcher.

    Every call yields a new AsyncMock whose ``run``, ``receive`` and
    ``destroy`` members are themselves AsyncMocks, so they can be
    awaited and their calls inspected.
    """
    async def build(ws, running, match_info):
        return AsyncMock(
            run=AsyncMock(),
            receive=AsyncMock(),
            destroy=AsyncMock(),
        )

    return build
|
|
|
|
|
|
@pytest.fixture
def socket_endpoint(mock_auth, mock_dispatcher_factory):
    """SocketEndpoint wired to the mock authenticator and dispatcher factory."""
    options = {
        "endpoint_path": "/test-socket",
        "auth": mock_auth,
        "dispatcher": mock_dispatcher_factory,
        "capability": TEST_CAP,
    }
    return SocketEndpoint(**options)
|
|
|
|
|
|
@pytest.fixture
def mock_websocket():
    """Open WebSocketResponse double with awaitable ``prepare``/``close``.

    The mock is spec'd against ``web.WebSocketResponse`` and reports
    ``closed=False``.
    """
    return AsyncMock(
        spec=web.WebSocketResponse,
        prepare=AsyncMock(),
        close=AsyncMock(),
        closed=False,
    )
|
|
|
|
|
|
@pytest.fixture
def mock_request():
    """HTTP request double with a ``token`` query parameter and empty match info."""
    return MagicMock(query={"token": "test-token"}, match_info={})
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_listener_graceful_shutdown_on_close():
    """A CLOSE frame ends the listener loop and triggers graceful shutdown.

    The preceding TEXT frame must reach the dispatcher exactly once,
    ``running`` must end up stopped, and the listener must sleep once
    for a 1.0 second grace period.
    """
    endpoint = SocketEndpoint(
        "/test", MagicMock(), AsyncMock(),
        capability=TEST_CAP,
    )

    def frame(kind):
        message = MagicMock()
        message.type = kind
        return message

    # Async iteration over the socket yields one TEXT frame, then CLOSE.
    async def frames(_self):
        for kind in (WSMsgType.TEXT, WSMsgType.CLOSE):
            yield frame(kind)

    websocket = AsyncMock()
    websocket.__aiter__ = frames

    dispatcher = AsyncMock()
    running = Running()

    with patch('asyncio.sleep') as fake_sleep:
        await endpoint.listener(websocket, dispatcher, running)

    dispatcher.receive.assert_called_once()
    assert running.get() is False
    fake_sleep.assert_called_once_with(1.0)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_normal_flow():
    """Valid bearer → handshake accepted, dispatcher created.

    Also pins that the query-string token is presented to the
    authenticator as an ``Authorization: Bearer`` header, the same
    wrapping that ``test_handle_unauthorized_request`` checks on the
    rejection path.
    """
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
    mock_auth.authorise = AsyncMock(return_value=None)

    dispatcher_created = False

    async def mock_dispatcher_factory(ws, running, match_info):
        nonlocal dispatcher_created
        dispatcher_created = True
        dispatcher = AsyncMock()
        dispatcher.destroy = AsyncMock()
        return dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, mock_dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    mock_ws = AsyncMock()
    mock_ws.prepare = AsyncMock()
    mock_ws.close = AsyncMock()
    mock_ws.closed = False

    # Stand-in TaskGroup: entering yields itself and exiting returns
    # None without raising, so handle() runs straight through.
    mock_tg = AsyncMock()
    mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
    mock_tg.__aexit__ = AsyncMock(return_value=None)

    def create_task_mock(coro):
        # The coroutine is never scheduled; close it so Python does not
        # warn that it "was never awaited".
        coro.close()
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    mock_tg.create_task = MagicMock(side_effect=create_task_mock)

    with (
        patch('aiohttp.web.WebSocketResponse') as mock_ws_class,
        patch('asyncio.TaskGroup') as mock_task_group,
    ):
        mock_ws_class.return_value = mock_ws
        mock_task_group.return_value = mock_tg

        result = await socket_endpoint.handle(request)

    # Handshake auth: the query-string token reaches the authenticator
    # as a Bearer header on the request it is given.
    mock_auth.authenticate.assert_called_once()
    passed_req = mock_auth.authenticate.call_args.args[0]
    assert passed_req.headers["Authorization"] == "Bearer valid-token"

    # Should have created dispatcher
    assert dispatcher_created is True

    # Should return the very websocket that was prepared
    assert result is mock_ws
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_exception_group_cleanup():
    """An ExceptionGroup escaping the TaskGroup still triggers cleanup.

    After the TaskGroup exit raises, handle() must call the patched
    ``asyncio.wait_for`` exactly once, call ``destroy`` on the
    dispatcher, and close the WebSocket.
    """
    auth = MagicMock()
    auth.authenticate = AsyncMock(return_value=_valid_identity())
    auth.authorise = AsyncMock(return_value=None)

    dispatcher = AsyncMock()
    dispatcher.destroy = AsyncMock()

    async def dispatcher_factory(ws, running, match_info):
        return dispatcher

    endpoint = SocketEndpoint(
        "/test", auth, dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    class InjectedError(Exception):
        pass

    failure = ExceptionGroup("Test exceptions", [InjectedError("test")])

    websocket = AsyncMock()
    websocket.prepare = AsyncMock()
    websocket.close = AsyncMock()
    websocket.closed = False

    # TaskGroup double whose exit raises the injected ExceptionGroup.
    group = AsyncMock()
    group.__aenter__ = AsyncMock(return_value=group)
    group.__aexit__ = AsyncMock(side_effect=failure)

    def fake_create_task(coro):
        coro.close()  # never scheduled; close to avoid "never awaited"
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    group.create_task = MagicMock(side_effect=fake_create_task)

    async def fake_wait_for(coro, timeout=None):
        coro.close()  # discard the cleanup coroutine instead of running it
        return None

    with (
        patch('aiohttp.web.WebSocketResponse') as ws_class,
        patch('asyncio.TaskGroup') as task_group_class,
        patch(
            'trustgraph.gateway.endpoint.socket.asyncio.wait_for',
            new_callable=AsyncMock,
        ) as wait_for,
    ):
        ws_class.return_value = websocket
        task_group_class.return_value = group
        wait_for.side_effect = fake_wait_for

        await endpoint.handle(request)

    wait_for.assert_called_once()
    assert dispatcher.destroy.call_count >= 1
    websocket.close.assert_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_dispatcher_cleanup_timeout():
    """Dispatcher cleanup that times out still ends in ``destroy``.

    The TaskGroup exit raises an ExceptionGroup. The patched
    ``asyncio.wait_for`` then raises ``asyncio.TimeoutError``,
    simulating a graceful cleanup that overruns its budget. handle()
    must have requested a 5.0 second timeout and must still call
    ``destroy`` on the dispatcher.
    """
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
    mock_auth.authorise = AsyncMock(return_value=None)

    # destroy() itself returns immediately. The "slow cleanup" is
    # simulated by the patched wait_for below, not by the dispatcher.
    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, mock_dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    # The TaskGroup exit raises this to drive handle() into cleanup.
    exception_group = ExceptionGroup("Test", [Exception("test")])

    mock_ws = AsyncMock()
    mock_ws.prepare = AsyncMock()
    mock_ws.close = AsyncMock()
    mock_ws.closed = False

    mock_tg = AsyncMock()
    mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
    mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)

    def create_task_mock(coro):
        # Never scheduled; close it to avoid "was never awaited" warnings.
        coro.close()
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    mock_tg.create_task = MagicMock(side_effect=create_task_mock)

    async def wait_for_timeout(coro, timeout=None):
        # Discard the cleanup coroutine, then report that it timed out.
        coro.close()
        raise asyncio.TimeoutError("Cleanup timeout")

    with (
        patch('aiohttp.web.WebSocketResponse') as mock_ws_class,
        patch('asyncio.TaskGroup') as mock_task_group,
        patch(
            'trustgraph.gateway.endpoint.socket.asyncio.wait_for',
            new_callable=AsyncMock,
        ) as mock_wait_for,
    ):
        mock_ws_class.return_value = mock_ws
        mock_task_group.return_value = mock_tg
        mock_wait_for.side_effect = wait_for_timeout

        await socket_endpoint.handle(request)

    # Should have attempted cleanup with timeout
    mock_wait_for.assert_called_once()

    # Accept the timeout whether it was passed by keyword or positionally
    # (``wait_for(aw, timeout)``) so the test pins the value, not the
    # calling convention.
    call = mock_wait_for.call_args
    timeout = call.kwargs.get(
        "timeout", call.args[1] if len(call.args) > 1 else None,
    )
    assert timeout == 5.0

    # Should still call destroy in finally block
    assert mock_dispatcher.destroy.call_count >= 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_unauthorized_request():
    """A bearer that the IAM layer rejects causes the handshake to
    fail with 401. IamAuth surfaces an HTTPUnauthorized; the
    endpoint catches it and returns it as the handshake response.
    Note that the endpoint intentionally does NOT distinguish 'bad
    token', 'expired', 'revoked', etc. — that's the IAM
    error-masking policy.

    Authentication failure is terminal (fail closed): neither the
    authorisation step nor the dispatcher factory may be reached.
    """
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(side_effect=web.HTTPUnauthorized(
        text='{"error":"auth failure"}',
        content_type="application/json",
    ))
    # Explicit async double so the "never authorised" check below is
    # made against the coroutine API the endpoint would await.
    mock_auth.authorise = AsyncMock()

    dispatcher_factory = AsyncMock()

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "invalid-token"}

    result = await socket_endpoint.handle(request)

    assert isinstance(result, web.HTTPUnauthorized)
    # authenticate must have been invoked with a synthetic request
    # carrying Bearer <the-token>. The endpoint wraps the query-
    # string token into an Authorization header for a uniform auth
    # path — the IAM layer does not look at query strings directly.
    mock_auth.authenticate.assert_called_once()
    passed_req = mock_auth.authenticate.call_args.args[0]
    assert passed_req.headers["Authorization"] == "Bearer invalid-token"

    # No identity was established, so nothing downstream may run.
    mock_auth.authorise.assert_not_called()
    dispatcher_factory.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_missing_token():
    """No ``token`` query parameter → 401 without consulting IAM.

    The missing token is detected up front as a cheap short-circuit,
    so the authenticator must never be invoked.
    """
    never_called = AssertionError(
        "authenticate must not be invoked when no token is present"
    )
    auth = MagicMock()
    auth.authenticate = AsyncMock(side_effect=never_called)

    endpoint = SocketEndpoint(
        "/test", auth, AsyncMock(),
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {}  # deliberately empty: no token supplied

    response = await endpoint.handle(request)

    assert isinstance(response, web.HTTPUnauthorized)
    auth.authenticate.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_websocket_already_closed():
    """A WebSocket that already reports ``closed=True`` is not re-closed.

    The dispatcher must still be destroyed, but handle() must not call
    ``close()`` on the socket during cleanup.
    """
    auth = MagicMock()
    auth.authenticate = AsyncMock(return_value=_valid_identity())
    auth.authorise = AsyncMock(return_value=None)

    dispatcher = AsyncMock()
    dispatcher.destroy = AsyncMock()

    async def dispatcher_factory(ws, running, match_info):
        return dispatcher

    endpoint = SocketEndpoint(
        "/test", auth, dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    websocket = AsyncMock()
    websocket.prepare = AsyncMock()
    websocket.close = AsyncMock()
    websocket.closed = True  # simulate a socket that is already closed

    # TaskGroup double: enter yields itself, exit returns None quietly.
    group = AsyncMock()
    group.__aenter__ = AsyncMock(return_value=group)
    group.__aexit__ = AsyncMock(return_value=None)

    def fake_create_task(coro):
        coro.close()  # never scheduled; close to avoid "never awaited"
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    group.create_task = MagicMock(side_effect=fake_create_task)

    with (
        patch('aiohttp.web.WebSocketResponse') as ws_class,
        patch('asyncio.TaskGroup') as task_group_class,
    ):
        ws_class.return_value = websocket
        task_group_class.return_value = group

        await endpoint.handle(request)

    dispatcher.destroy.assert_called()
    # Not called in finally since ws.closed = True
    websocket.close.assert_not_called()