mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-30 08:59:37 +02:00
Users are global entities, not scoped to workspaces. This change:

Track A — Global usernames:
- Change iam_users_by_username to PRIMARY KEY (username), removing workspace from the lookup key
- Login looks up username globally, no workspace required
- Username uniqueness is enforced globally, not per-workspace
- Login -w now overrides the JWT workspace (session workspace) rather than selecting which user registry to search

Track B — Rename workspace to default_workspace:
- UserRecord.workspace → UserRecord.default_workspace
- Identity.workspace → Identity.default_workspace
- JWT claim "workspace" → "default_workspace"
- IamResponse.resolved_workspace → resolved_default_workspace
- WebSocket auth-ok frame field → default_workspace
- Socket clients read default_workspace from auth-ok
- _user_record_to_dict wire key → default_workspace
- CLI help text and output updated throughout
- Test files updated for renamed fields
442 lines
No EOL
15 KiB
Python
442 lines
No EOL
15 KiB
Python
"""Unit tests for SocketEndpoint graceful shutdown functionality.
|
|
|
|
These tests exercise SocketEndpoint in its handshake-auth
|
|
configuration (``in_band_auth=False``) — the mode used in production
|
|
for the flow import/export streaming endpoints. The mux socket at
|
|
``/api/v1/socket`` uses ``in_band_auth=True`` instead, where the
|
|
handshake always accepts and authentication runs on the first
|
|
WebSocket frame; that path is covered by the Mux tests.
|
|
|
|
Every endpoint constructor here passes an explicit capability — no
|
|
permissive default is relied upon.
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from aiohttp import web, WSMsgType
|
|
from trustgraph.gateway.endpoint.socket import SocketEndpoint
|
|
from trustgraph.gateway.running import Running
|
|
from trustgraph.gateway.auth import Identity
|
|
|
|
|
|
# Representative capability used across these tests — corresponds to
# the flow-import streaming endpoint pattern that uses this class.
# Passed as ``capability=`` to every SocketEndpoint constructed in
# this module.
TEST_CAP = "graph:write"
|
|
|
|
|
|
def _valid_identity():
    """Build the Identity returned by the mocked authenticator on success.

    The handle and principal id are both ``"test-user"``, the default
    workspace is ``"default"`` and the credential source is ``"api-key"``.
    """
    identity_fields = {
        "handle": "test-user",
        "default_workspace": "default",
        "principal_id": "test-user",
        "source": "api-key",
    }
    return Identity(**identity_fields)
|
|
|
|
|
|
@pytest.fixture
def mock_auth():
    """Provide a mocked authenticator whose calls succeed by default.

    ``authenticate`` resolves to the identity built by
    ``_valid_identity()`` and ``authorise`` resolves to ``None``.  A test
    that needs a failure path replaces the relevant attribute on the
    returned mock.
    """
    # Constructor keyword arguments are set as attributes on the new mock.
    return MagicMock(
        authenticate=AsyncMock(return_value=_valid_identity()),
        authorise=AsyncMock(return_value=None),
    )
|
|
|
|
|
|
@pytest.fixture
def mock_dispatcher_factory():
    """Provide an async factory that builds a fresh mocked dispatcher.

    The factory accepts ``(ws, running, match_info)``, ignores all three,
    and returns an ``AsyncMock`` whose ``run``, ``receive`` and
    ``destroy`` attributes are themselves ``AsyncMock`` instances.
    """
    async def dispatcher_factory(ws, running, match_info):
        return AsyncMock(
            run=AsyncMock(),
            receive=AsyncMock(),
            destroy=AsyncMock(),
        )

    return dispatcher_factory
|
|
|
|
|
|
@pytest.fixture
def socket_endpoint(mock_auth, mock_dispatcher_factory):
    """Provide a SocketEndpoint at ``/test-socket`` wired to the mock fixtures.

    The endpoint uses the mocked authenticator, the mocked dispatcher
    factory and the module-wide ``TEST_CAP`` capability.
    """
    endpoint_kwargs = dict(
        endpoint_path="/test-socket",
        auth=mock_auth,
        dispatcher=mock_dispatcher_factory,
        capability=TEST_CAP,
    )
    return SocketEndpoint(**endpoint_kwargs)
|
|
|
|
|
|
@pytest.fixture
def mock_websocket():
    """Provide a mock specced against ``web.WebSocketResponse``.

    ``prepare`` and ``close`` are awaitable mocks and ``closed`` starts
    out as ``False``.
    """
    # Constructor keyword arguments are applied as attributes on the mock.
    return AsyncMock(
        spec=web.WebSocketResponse,
        prepare=AsyncMock(),
        close=AsyncMock(),
        closed=False,
    )
|
|
|
|
|
|
@pytest.fixture
def mock_request():
    """Provide a mocked HTTP request.

    The request carries a ``token`` query parameter set to
    ``"test-token"`` and an empty ``match_info`` mapping.
    """
    return MagicMock(
        query={"token": "test-token"},
        match_info={},
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_listener_graceful_shutdown_on_close():
    """The listener shuts down gracefully when the websocket sends CLOSE.

    The fake websocket yields one TEXT frame followed by one CLOSE
    frame.  The test expects exactly one ``dispatcher.receive`` call, the
    ``Running`` flag cleared, and a single 1.0 second ``asyncio.sleep``.
    """
    endpoint = SocketEndpoint(
        "/test", MagicMock(), AsyncMock(),
        capability=TEST_CAP,
    )

    async def frames(self):
        # One ordinary text frame, then the close frame that ends the stream.
        for frame_type in (WSMsgType.TEXT, WSMsgType.CLOSE):
            frame = MagicMock()
            frame.type = frame_type
            yield frame

    # A function installed as a magic method receives the mock as ``self``.
    ws = AsyncMock()
    ws.__aiter__ = frames

    running = Running()
    dispatcher = AsyncMock()

    with patch('asyncio.sleep') as sleep_mock:
        await endpoint.listener(ws, dispatcher, running)

    # Only the text frame is forwarded to the dispatcher.
    dispatcher.receive.assert_called_once()

    # The close frame flips the running flag off.
    assert running.get() is False

    # The listener waits out the grace period exactly once.
    sleep_mock.assert_called_once_with(1.0)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_normal_flow():
    """Valid bearer → handshake accepted, dispatcher created.

    ``web.WebSocketResponse`` and ``asyncio.TaskGroup`` are both patched,
    so no real socket or task is created.  The test asserts that the
    dispatcher factory ran and that ``handle`` returned the websocket.
    """
    auth = MagicMock(
        authenticate=AsyncMock(return_value=_valid_identity()),
        authorise=AsyncMock(return_value=None),
    )

    # Mutable flag flipped by the factory so the call can be observed.
    seen = {"dispatcher_created": False}

    async def mock_dispatcher_factory(ws, running, match_info):
        seen["dispatcher_created"] = True
        return AsyncMock(destroy=AsyncMock())

    endpoint = SocketEndpoint(
        "/test", auth, mock_dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock(query={"token": "valid-token"}, match_info={})

    fake_ws = AsyncMock(prepare=AsyncMock(), close=AsyncMock(), closed=False)

    def fake_create_task(coro):
        # Close the coroutine so it never raises "was never awaited".
        coro.close()
        # Looks like an already-finished, non-cancelled asyncio.Task.
        return AsyncMock(
            done=MagicMock(return_value=True),
            cancelled=MagicMock(return_value=False),
        )

    # Stand-in for the TaskGroup async context manager; exits cleanly.
    task_group = AsyncMock()
    task_group.__aenter__ = AsyncMock(return_value=task_group)
    task_group.__aexit__ = AsyncMock(return_value=None)
    task_group.create_task = MagicMock(side_effect=fake_create_task)

    with (
        patch('aiohttp.web.WebSocketResponse', return_value=fake_ws),
        patch('asyncio.TaskGroup', return_value=task_group),
    ):
        result = await endpoint.handle(request)

    # The factory must have been invoked.
    assert seen["dispatcher_created"] is True

    # The handler hands back the websocket it prepared.
    assert result == fake_ws
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_exception_group_cleanup():
    """An ExceptionGroup raised by the task group triggers cleanup.

    The patched ``asyncio.TaskGroup`` raises an ``ExceptionGroup`` from
    ``__aexit__``.  The test expects one ``asyncio.wait_for`` call, at
    least one ``dispatcher.destroy`` call, and the websocket closed.
    """
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
    mock_auth.authorise = AsyncMock(return_value=None)

    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, mock_dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    # Error the mocked TaskGroup raises when its context exits.
    class TestException(Exception):
        pass

    exception_group = ExceptionGroup("Test exceptions", [TestException("test")])

    mock_ws = AsyncMock()
    mock_ws.prepare = AsyncMock()
    mock_ws.close = AsyncMock()
    mock_ws.closed = False

    # Create proper mock tasks that look like asyncio.Task objects
    def create_task_mock(coro):
        # Consume the coroutine to avoid "was never awaited" warning
        coro.close()
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    mock_tg = AsyncMock()
    mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
    mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
    mock_tg.create_task = MagicMock(side_effect=create_task_mock)

    # Make wait_for consume the coroutine passed to it
    async def wait_for_side_effect(coro, timeout=None):
        coro.close()  # Consume the coroutine
        return None

    with (
        patch('aiohttp.web.WebSocketResponse', return_value=mock_ws),
        patch('asyncio.TaskGroup', return_value=mock_tg),
        patch(
            'trustgraph.gateway.endpoint.socket.asyncio.wait_for',
            new_callable=AsyncMock,
            side_effect=wait_for_side_effect,
        ) as mock_wait_for,
    ):
        # The return value is not inspected here, so it is not bound.
        await socket_endpoint.handle(request)

    # Should have attempted graceful cleanup
    mock_wait_for.assert_called_once()

    # Should have called destroy in finally block
    assert mock_dispatcher.destroy.call_count >= 1

    # Should have closed websocket
    mock_ws.close.assert_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_dispatcher_cleanup_timeout():
    """Dispatcher cleanup that times out still ends in ``destroy``.

    The patched ``asyncio.TaskGroup`` raises an ``ExceptionGroup`` and
    the patched ``asyncio.wait_for`` raises ``asyncio.TimeoutError``.
    The test expects ``wait_for`` to be called once with
    ``timeout=5.0`` and ``dispatcher.destroy`` to be called at least once.
    """
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
    mock_auth.authorise = AsyncMock(return_value=None)

    # ``destroy`` itself returns immediately; the timeout is simulated
    # by the patched ``asyncio.wait_for`` below.
    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, mock_dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    # Error the mocked TaskGroup raises when its context exits.
    exception_group = ExceptionGroup("Test", [Exception("test")])

    mock_ws = AsyncMock()
    mock_ws.prepare = AsyncMock()
    mock_ws.close = AsyncMock()
    mock_ws.closed = False

    # Create proper mock tasks that look like asyncio.Task objects
    def create_task_mock(coro):
        # Consume the coroutine to avoid "was never awaited" warning
        coro.close()
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    mock_tg = AsyncMock()
    mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
    mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
    mock_tg.create_task = MagicMock(side_effect=create_task_mock)

    # Make wait_for consume the coroutine before raising
    async def wait_for_timeout(coro, timeout=None):
        coro.close()  # Consume the coroutine
        raise asyncio.TimeoutError("Cleanup timeout")

    with (
        patch('aiohttp.web.WebSocketResponse', return_value=mock_ws),
        patch('asyncio.TaskGroup', return_value=mock_tg),
        patch(
            'trustgraph.gateway.endpoint.socket.asyncio.wait_for',
            new_callable=AsyncMock,
            side_effect=wait_for_timeout,
        ) as mock_wait_for,
    ):
        # The return value is not inspected here, so it is not bound.
        await socket_endpoint.handle(request)

    # Should have attempted cleanup with timeout
    mock_wait_for.assert_called_once()
    # Check that timeout was passed correctly (as a keyword argument)
    assert mock_wait_for.call_args.kwargs['timeout'] == 5.0

    # Should still call destroy in finally block
    assert mock_dispatcher.destroy.call_count >= 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_unauthorized_request():
    """A bearer rejected by the authenticator yields an HTTPUnauthorized.

    ``authenticate`` is mocked to raise ``web.HTTPUnauthorized`` and the
    test asserts that ``handle`` returns an instance of it.  The
    endpoint intentionally does NOT distinguish 'bad token', 'expired',
    'revoked', etc. — that's the IAM error-masking policy.
    """
    rejection = web.HTTPUnauthorized(
        text='{"error":"auth failure"}',
        content_type="application/json",
    )
    auth = MagicMock()
    auth.authenticate = AsyncMock(side_effect=rejection)

    endpoint = SocketEndpoint(
        "/test", auth, AsyncMock(),
        capability=TEST_CAP,
    )

    request = MagicMock(query={"token": "invalid-token"})

    outcome = await endpoint.handle(request)
    assert isinstance(outcome, web.HTTPUnauthorized)

    # The query-string token must reach authenticate as an Authorization
    # header of the form "Bearer <token>" on the request object passed
    # in — the IAM layer does not look at query strings directly.
    auth.authenticate.assert_called_once()
    forwarded_request, *_ = auth.authenticate.call_args.args
    assert forwarded_request.headers["Authorization"] == "Bearer invalid-token"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_missing_token():
    """No ``token`` query param → HTTPUnauthorized without an IAM call.

    ``authenticate`` is rigged to raise ``AssertionError`` if it is ever
    invoked, and the test also asserts that it was not called.
    """
    tripwire = AssertionError(
        "authenticate must not be invoked when no token is present"
    )
    auth = MagicMock()
    auth.authenticate = AsyncMock(side_effect=tripwire)

    endpoint = SocketEndpoint(
        "/test", auth, AsyncMock(),
        capability=TEST_CAP,
    )

    # Empty query mapping: no token supplied.
    request = MagicMock(query={})

    outcome = await endpoint.handle(request)

    assert isinstance(outcome, web.HTTPUnauthorized)
    auth.authenticate.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_websocket_already_closed():
    """An already-closed websocket is not closed again during cleanup.

    The mocked websocket reports ``closed = True`` and the mocked task
    group exits cleanly.  The test expects ``dispatcher.destroy`` to be
    called and ``ws.close`` never to be called.
    """
    mock_auth = MagicMock()
    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
    mock_auth.authorise = AsyncMock(return_value=None)

    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint(
        "/test", mock_auth, mock_dispatcher_factory,
        capability=TEST_CAP,
    )

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    mock_ws = AsyncMock()
    mock_ws.prepare = AsyncMock()
    mock_ws.close = AsyncMock()
    mock_ws.closed = True  # Already closed

    # Create proper mock tasks that look like asyncio.Task objects
    def create_task_mock(coro):
        # Consume the coroutine to avoid "was never awaited" warning
        coro.close()
        task = AsyncMock()
        task.done = MagicMock(return_value=True)
        task.cancelled = MagicMock(return_value=False)
        return task

    mock_tg = AsyncMock()
    mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
    mock_tg.__aexit__ = AsyncMock(return_value=None)
    mock_tg.create_task = MagicMock(side_effect=create_task_mock)

    with (
        patch('aiohttp.web.WebSocketResponse', return_value=mock_ws),
        patch('asyncio.TaskGroup', return_value=mock_tg),
    ):
        # The return value is not inspected here, so it is not bound.
        await socket_endpoint.handle(request)

    # Should still have called destroy
    mock_dispatcher.destroy.assert_called()

    # Should not attempt to close already closed websocket
    mock_ws.close.assert_not_called()  # Not called in finally since ws.closed = True