Update rev-gateway for IAM integration (#940)

service.py:
- Constructor takes **config (same pattern as api-gateway) instead
  of individual args
- Creates IamAuth and calls await self.auth.start() before the
  message loop
- Passes auth to both ConfigReceiver and MessageDispatcher
- Uses add_pubsub_args / add_logging_args instead of hand-rolled
  Pulsar args
- Passes timeout through

dispatcher.py:
- Accepts auth and timeout parameters
- Passes both to DispatcherManager — fixes the missing auth argument
  that would have crashed on startup

The remote end's requests now go through the same IAM authentication
path as api-gateway. Token validation, workspace resolution, and
permissions all work identically regardless of which direction
initiated the connection.

Fixed tests — the test now passes auth and timeout to MessageDispatcher
and verifies they're forwarded to DispatcherManager.

Update rev gateway dispatcher to align with IAM.  A "token" parameter
must be passed with each message.

Fix websocket relay to align with rev-gateway changes, conforms to
the api-gateway protocol.
This commit is contained in:
cybermaggedon 2026-05-19 21:45:43 +01:00 committed by GitHub
parent 4e3bd85abc
commit e57f4669e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 914 additions and 865 deletions

View file

@ -25,16 +25,17 @@ class TestSemaphoreEnforcement:
max_concurrent = 0
processing_event = asyncio.Event()
async def slow_process(message):
async def slow_process(message, sender):
nonlocal concurrent_count, max_concurrent
concurrent_count += 1
max_concurrent = max(max_concurrent, concurrent_count)
await asyncio.sleep(0.05)
concurrent_count -= 1
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = slow_process
sender = AsyncMock()
# Launch more tasks than max_workers
messages = [
{"id": f"msg-{i}", "service": "test", "request": {}}
@ -42,7 +43,7 @@ class TestSemaphoreEnforcement:
]
tasks = [
asyncio.create_task(dispatcher.handle_message(m))
asyncio.create_task(dispatcher.handle_message(m, sender))
for m in messages
]
@ -66,17 +67,17 @@ class TestSemaphoreEnforcement:
original_process = dispatcher._process_message
async def tracking_process(message):
async def tracking_process(message, sender):
nonlocal task_was_tracked
# During processing, our task should be in active_tasks
if len(dispatcher.active_tasks) > 0:
task_was_tracked = True
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = tracking_process
await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}}
{"id": "test", "service": "test", "request": {}},
AsyncMock(),
)
assert task_was_tracked
@ -88,7 +89,7 @@ class TestSemaphoreEnforcement:
"""Semaphore should be released even if processing raises."""
dispatcher = MessageDispatcher(max_workers=2)
async def failing_process(message):
async def failing_process(message, sender):
raise RuntimeError("process failed")
dispatcher._process_message = failing_process
@ -96,7 +97,8 @@ class TestSemaphoreEnforcement:
# Should not deadlock — semaphore must be released on error
with pytest.raises(RuntimeError):
await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}}
{"id": "test", "service": "test", "request": {}},
AsyncMock(),
)
# Semaphore should be back at max
@ -109,17 +111,18 @@ class TestSemaphoreEnforcement:
order = []
async def ordered_process(message):
async def ordered_process(message, sender):
msg_id = message["id"]
order.append(f"start-{msg_id}")
await asyncio.sleep(0.02)
order.append(f"end-{msg_id}")
return {"id": msg_id, "response": {"ok": True}}
dispatcher._process_message = ordered_process
sender = AsyncMock()
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages]
await asyncio.gather(*tasks)
# With semaphore=1, each message should complete before next starts

View file

View file

@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch, ANY
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher
class TestWebSocketResponder:
"""Test cases for WebSocketResponder class"""
def test_websocket_responder_initialization(self):
"""Test WebSocketResponder initialization"""
responder = WebSocketResponder()
assert responder.response is None
assert responder.completed is False
@pytest.mark.asyncio
async def test_websocket_responder_send_method(self):
"""Test WebSocketResponder send method"""
responder = WebSocketResponder()
test_response = {"data": "test response"}
# Call send method
await responder.send(test_response)
# Verify response was stored
assert responder.response == test_response
@pytest.mark.asyncio
async def test_websocket_responder_call_method(self):
"""Test WebSocketResponder __call__ method"""
responder = WebSocketResponder()
test_response = {"result": "success"}
test_completed = True
# Call the responder
await responder(test_response, test_completed)
# Verify response and completed status were set
assert responder.response == test_response
assert responder.completed == test_completed
@pytest.mark.asyncio
async def test_websocket_responder_call_method_with_false_completion(self):
"""Test WebSocketResponder __call__ method with incomplete response"""
responder = WebSocketResponder()
test_response = {"partial": "data"}
test_completed = False
# Call the responder
await responder(test_response, test_completed)
# Verify response was set and completed is True (since send() always sets completed=True)
assert responder.response == test_response
assert responder.completed is True
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
class TestMessageDispatcher:
"""Test cases for MessageDispatcher class"""
def test_message_dispatcher_initialization_with_defaults(self):
"""Test MessageDispatcher initialization with default parameters"""
dispatcher = MessageDispatcher()
assert dispatcher.max_workers == 10
assert dispatcher.semaphore._value == 10
assert dispatcher.active_tasks == set()
assert dispatcher.backend is None
assert dispatcher.auth is None
assert dispatcher.dispatcher_manager is None
assert len(dispatcher.service_mapping) > 0
def test_message_dispatcher_initialization_with_custom_workers(self):
"""Test MessageDispatcher initialization with custom max_workers"""
dispatcher = MessageDispatcher(max_workers=5)
assert dispatcher.max_workers == 5
assert dispatcher.semaphore._value == 5
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
"""Test MessageDispatcher initialization with pulsar_client and config_receiver"""
def test_message_dispatcher_initialization_with_backend(
self, mock_dispatcher_manager,
):
mock_backend = MagicMock()
mock_config_receiver = MagicMock()
mock_auth = MagicMock()
mock_dispatcher_instance = MagicMock()
mock_dispatcher_manager.return_value = mock_dispatcher_instance
dispatcher = MessageDispatcher(
max_workers=8,
config_receiver=mock_config_receiver,
backend=mock_backend
backend=mock_backend,
auth=mock_auth,
timeout=300,
)
assert dispatcher.max_workers == 8
assert dispatcher.backend == mock_backend
assert dispatcher.auth == mock_auth
assert dispatcher.dispatcher_manager == mock_dispatcher_instance
mock_dispatcher_manager.assert_called_once_with(
mock_backend, mock_config_receiver, prefix="rev-gateway"
mock_backend, mock_config_receiver,
auth=mock_auth, prefix="rev-gateway", timeout=300,
)
def test_message_dispatcher_service_mapping(self):
"""Test MessageDispatcher service mapping contains expected services"""
dispatcher = MessageDispatcher()
expected_services = [
"text-completion", "graph-rag", "agent", "embeddings",
"graph-embeddings", "triples", "document-load", "text-load",
"flow", "knowledge", "config", "librarian", "document-rag"
"flow", "knowledge", "config", "librarian", "document-rag",
]
for service in expected_services:
assert service in dispatcher.service_mapping
# Test specific mappings
assert dispatcher.service_mapping["text-completion"] == "text-completion"
assert dispatcher.service_mapping["document-load"] == "document"
assert dispatcher.service_mapping["text-load"] == "text-document"
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self):
"""Test MessageDispatcher handle_message without dispatcher manager"""
async def test_handle_message_without_dispatcher_manager(self):
dispatcher = MessageDispatcher()
test_message = {
"id": "test-123",
"service": "test-service",
"request": {"data": "test"}
}
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-123"
assert "error" in result["response"]
assert "DispatcherManager not available" in result["response"]["error"]
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="default")
)
sender = AsyncMock()
await dispatcher.handle_message(
{"id": "test-1", "service": "test", "request": {}},
sender,
)
sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-1"
assert sent["error"]["message"] == "DispatcherManager not available"
assert sent["error"]["type"] == "error"
assert sent["complete"] is True
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_with_exception(self):
"""Test MessageDispatcher handle_message with exception during processing"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
async def test_handle_message_auth_failure(self):
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-456",
"service": "text-completion",
"request": {"prompt": "test"}
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-456"
assert "error" in result["response"]
assert "Test error" in result["response"]["error"]
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
side_effect=Exception("auth failure")
)
dispatcher.dispatcher_manager = MagicMock()
sender = AsyncMock()
await dispatcher.handle_message(
{"id": "test-2", "token": "bad", "service": "test", "request": {}},
sender,
)
sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-2"
assert "auth failure" in sent["error"]["message"]
assert sent["complete"] is True
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_global_service(self):
"""Test MessageDispatcher handle_message with global service"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"result": "success"}
async def test_handle_message_global_service(self):
mock_dm = MagicMock()
mock_dm.invoke_global_service = AsyncMock()
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-789",
"service": "text-completion",
"request": {"prompt": "hello"}
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-789"
assert result["response"] == {"result": "success"}
mock_dispatcher_manager.invoke_global_service.assert_called_once()
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws1")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers',
{"text-completion": True},
):
await dispatcher.handle_message(
{
"id": "test-3",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hello"},
},
sender,
)
mock_dm.invoke_global_service.assert_called_once()
args, kwargs = mock_dm.invoke_global_service.call_args
assert args[0] == {"prompt": "hello"}
assert args[2] == "text-completion"
assert kwargs["workspace"] == "ws1"
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_flow_service(self):
"""Test MessageDispatcher handle_message with flow service"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"data": "flow_result"}
async def test_handle_message_flow_service(self):
mock_dm = MagicMock()
mock_dm.invoke_flow_service = AsyncMock()
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-flow-123",
"service": "document-rag",
"request": {"query": "test"},
"flow": "custom-flow"
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-flow-123"
assert result["response"] == {"data": "flow_result"}
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with(
{"query": "test"}, mock_responder, "custom-flow", "document-rag"
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws2")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
):
await dispatcher.handle_message(
{
"id": "test-4",
"token": "tg_key",
"service": "document-rag",
"request": {"query": "test"},
"flow": "my-flow",
},
sender,
)
mock_dm.invoke_flow_service.assert_called_once_with(
{"query": "test"}, ANY, "ws2", "my-flow", "document-rag",
)
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_incomplete_response(self):
"""Test MessageDispatcher handle_message with incomplete response"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = False
mock_responder.response = None
async def test_handle_message_responder_sends_frames(self):
mock_dm = MagicMock()
async def fake_invoke(data, responder, svc, workspace=None):
await responder({"partial": 1}, False)
await responder({"partial": 2}, True)
mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke)
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-incomplete",
"service": "agent",
"request": {"input": "test"}
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws1")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers',
{"text-completion": True},
):
await dispatcher.handle_message(
{
"id": "test-5",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hi"},
},
sender,
)
assert sender.call_count == 2
first = sender.call_args_list[0][0][0]
second = sender.call_args_list[1][0][0]
assert first == {
"id": "test-5", "response": {"partial": 1}, "complete": False,
}
assert second == {
"id": "test-5", "response": {"partial": 2}, "complete": True,
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-incomplete"
assert result["response"] == {"error": "No response received"}
@pytest.mark.asyncio
async def test_message_dispatcher_shutdown(self):
"""Test MessageDispatcher shutdown method"""
import asyncio
async def test_handle_message_workspace_from_identity(self):
mock_dm = MagicMock()
mock_dm.invoke_flow_service = AsyncMock()
dispatcher = MessageDispatcher()
# Create actual async tasks
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="derived-ws")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
):
await dispatcher.handle_message(
{
"id": "test-6",
"token": "tg_key",
"service": "agent",
"request": {"question": "test"},
"flow": "default",
},
sender,
)
args = mock_dm.invoke_flow_service.call_args[0]
assert args[2] == "derived-ws"
@pytest.mark.asyncio
async def test_shutdown(self):
dispatcher = MessageDispatcher()
async def dummy_task():
await asyncio.sleep(0.01)
return "done"
task1 = asyncio.create_task(dummy_task())
task2 = asyncio.create_task(dummy_task())
dispatcher.active_tasks = {task1, task2}
# Call shutdown
await dispatcher.shutdown()
# Verify tasks were completed
assert task1.done()
assert task2.done()
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
@pytest.mark.asyncio
async def test_message_dispatcher_shutdown_with_no_tasks(self):
"""Test MessageDispatcher shutdown with no active tasks"""
async def test_shutdown_with_no_tasks(self):
dispatcher = MessageDispatcher()
# Call shutdown with no active tasks
await dispatcher.shutdown()
# Should complete without error
assert dispatcher.active_tasks == set()
assert dispatcher.active_tasks == set()

View file

@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock
from aiohttp import WSMsgType, ClientWebSocketResponse
import json
from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run
from trustgraph.rev_gateway.service import ReverseGateway, run
MOCK_PATCHES = [
'trustgraph.rev_gateway.service.IamAuth',
'trustgraph.rev_gateway.service.ConfigReceiver',
'trustgraph.rev_gateway.service.MessageDispatcher',
'trustgraph.rev_gateway.service.get_pubsub',
]
def make_gateway(**overrides):
config = {"websocket_uri": "ws://localhost:7650/out"}
config.update(overrides)
return ReverseGateway(**config)
class TestReverseGateway:
"""Test cases for ReverseGateway class"""
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with default parameters"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_defaults(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
assert gateway.websocket_uri == "ws://localhost:7650/out"
assert gateway.host == "localhost"
assert gateway.port == 7650
@ -33,25 +49,22 @@ class TestReverseGateway:
assert gateway.max_workers == 10
assert gateway.running is False
assert gateway.reconnect_delay == 3.0
assert gateway.pulsar_host == "pulsar://pulsar:6650"
assert gateway.pulsar_api_key is None
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with custom parameters"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway(
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_custom_params(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway(
websocket_uri="wss://example.com:8080/websocket",
max_workers=20,
pulsar_host="pulsar://custom:6650",
pulsar_api_key="test-key",
pulsar_listener="test-listener"
)
assert gateway.websocket_uri == "wss://example.com:8080/websocket"
assert gateway.host == "example.com"
assert gateway.port == 8080
@ -59,340 +72,360 @@ class TestReverseGateway:
assert gateway.path == "/websocket"
assert gateway.url == "wss://example.com:8080/websocket"
assert gateway.max_workers == 20
assert gateway.pulsar_host == "pulsar://custom:6650"
assert gateway.pulsar_api_key == "test-key"
assert gateway.pulsar_listener == "test-listener"
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with WebSocket URI missing path"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway(websocket_uri="ws://example.com")
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_with_missing_path(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway(websocket_uri="ws://example.com")
assert gateway.path == "/ws"
assert gateway.url == "ws://example.com/ws"
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with invalid WebSocket scheme"""
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_invalid_scheme(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
ReverseGateway(websocket_uri="http://example.com")
make_gateway(websocket_uri="http://example.com")
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with missing hostname"""
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_missing_hostname(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
ReverseGateway(websocket_uri="ws://")
make_gateway(websocket_uri="ws://")
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway creates backend with authentication"""
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_iam_auth_created(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway(
pulsar_api_key="test-key",
pulsar_listener="test-listener"
gateway = make_gateway(id="test-rev-gw")
mock_iam_auth.assert_called_once_with(
backend=mock_backend,
id="test-rev-gw",
)
# Verify get_pubsub was called with the correct parameters
mock_get_pubsub.assert_called_once_with(
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key="test-key",
pulsar_listener="test-listener"
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_config_receiver_gets_auth(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_auth_instance = MagicMock()
mock_iam_auth.return_value = mock_auth_instance
gateway = make_gateway()
mock_config_receiver.assert_called_once_with(
mock_backend, auth=mock_auth_instance,
)
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@patch('trustgraph.rev_gateway.service.ClientSession')
@pytest.mark.asyncio
async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway successful connection"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
async def test_reverse_gateway_connect_success(
self, mock_session_class, mock_get_pubsub,
mock_dispatcher, mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_session = AsyncMock()
mock_ws = AsyncMock()
mock_session.ws_connect.return_value = mock_ws
mock_session_class.return_value = mock_session
gateway = ReverseGateway()
gateway = make_gateway()
result = await gateway.connect()
assert result is True
assert gateway.session == mock_session
assert gateway.ws == mock_ws
mock_session.ws_connect.assert_called_once_with(gateway.url)
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@patch('trustgraph.rev_gateway.service.ClientSession')
@pytest.mark.asyncio
async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway connection failure"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
async def test_reverse_gateway_connect_failure(
self, mock_session_class, mock_get_pubsub,
mock_dispatcher, mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_session = AsyncMock()
mock_session.ws_connect.side_effect = Exception("Connection failed")
mock_session_class.return_value = mock_session
gateway = ReverseGateway()
gateway = make_gateway()
result = await gateway.connect()
assert result is False
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway disconnect"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock websocket and session
async def test_reverse_gateway_disconnect(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
mock_ws = AsyncMock()
mock_ws.closed = False
mock_session = AsyncMock()
mock_session.closed = False
gateway.ws = mock_ws
gateway.session = mock_session
await gateway.disconnect()
mock_ws.close.assert_called_once()
mock_session.close.assert_called_once()
assert gateway.ws is None
assert gateway.session is None
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway send message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock websocket
async def test_reverse_gateway_send_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
test_message = {"id": "test", "data": "hello"}
await gateway.send_message(test_message)
mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway send message with closed connection"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock closed websocket
async def test_reverse_gateway_send_message_closed_connection(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
mock_ws = AsyncMock()
mock_ws.closed = True
gateway.ws = mock_ws
test_message = {"id": "test", "data": "hello"}
await gateway.send_message(test_message)
# Should not call send_str on closed connection
mock_ws.send_str.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway handle message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_dispatcher_instance = AsyncMock()
mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
mock_dispatcher.return_value = mock_dispatcher_instance
gateway = ReverseGateway()
# Mock send_message
gateway.send_message = AsyncMock()
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
await gateway.handle_message(test_message)
mock_dispatcher_instance.handle_message.assert_called_once_with({
"id": "test",
"service": "test-service",
"request": {"data": "test"}
})
gateway.send_message.assert_called_once_with({"response": "success"})
async def test_reverse_gateway_handle_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_dispatcher_instance = AsyncMock()
mock_dispatcher.return_value = mock_dispatcher_instance
gateway = make_gateway()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway handle message with invalid JSON"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock send_message
gateway.send_message = AsyncMock()
test_message = 'invalid json'
# Should not raise exception
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
await gateway.handle_message(test_message)
# Should not call send_message due to error
mock_dispatcher_instance.handle_message.assert_called_once_with(
{
"id": "test",
"service": "test-service",
"request": {"data": "test"},
},
gateway.send_message,
)
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message_invalid_json(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.send_message = AsyncMock()
await gateway.handle_message('invalid json')
gateway.send_message.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway listen with text message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
async def test_reverse_gateway_listen_text_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
# Mock websocket
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock()
mock_msg.type = WSMsgType.TEXT
mock_msg.data = '{"test": "message"}'
# Mock receive to return message once, then raise exception to stop loop
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
# listen() catches exceptions and breaks, so no exception should be raised
await gateway.listen()
gateway.handle_message.assert_called_once_with('{"test": "message"}')
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway listen with binary message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
async def test_reverse_gateway_listen_binary_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
# Mock websocket
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock()
mock_msg.type = WSMsgType.BINARY
mock_msg.data = b'{"test": "binary"}'
# Mock receive to return message once, then raise exception to stop loop
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
# listen() catches exceptions and breaks, so no exception should be raised
await gateway.listen()
gateway.handle_message.assert_called_once_with('{"test": "binary"}')
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway listen with close message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
async def test_reverse_gateway_listen_close_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
# Mock websocket
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock()
mock_msg.type = WSMsgType.CLOSE
# Mock receive to return close message
mock_ws.receive.return_value = mock_msg
await gateway.listen()
# Should not call handle_message for close message
gateway.handle_message.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway shutdown"""
async def test_reverse_gateway_shutdown(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_dispatcher_instance = AsyncMock()
mock_dispatcher.return_value = mock_dispatcher_instance
gateway = ReverseGateway()
gateway = make_gateway()
gateway.running = True
# Mock disconnect
gateway.disconnect = AsyncMock()
await gateway.shutdown()
@ -402,46 +435,50 @@ class TestReverseGateway:
gateway.disconnect.assert_called_once()
mock_backend.close.assert_called_once()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway stop"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_stop(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
gateway.stop()
assert gateway.running is False
class TestReverseGatewayRun:
"""Test cases for ReverseGateway run method"""
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway run method with successful connect/listen cycle"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
async def test_reverse_gateway_run_successful_cycle(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_auth_instance = AsyncMock()
mock_iam_auth.return_value = mock_auth_instance
mock_config_receiver_instance = AsyncMock()
mock_config_receiver.return_value = mock_config_receiver_instance
gateway = ReverseGateway()
# Mock methods
gateway.connect = AsyncMock(return_value=True)
gateway = make_gateway()
gateway.listen = AsyncMock()
gateway.disconnect = AsyncMock()
gateway.shutdown = AsyncMock()
# Stop after one iteration
call_count = 0
async def mock_connect():
nonlocal call_count
@ -451,91 +488,13 @@ class TestReverseGatewayRun:
else:
gateway.running = False
return False
gateway.connect = mock_connect
await gateway.run()
mock_auth_instance.start.assert_called_once()
mock_config_receiver_instance.start.assert_called_once()
gateway.listen.assert_called_once()
# disconnect is called twice: once in the main loop, once in shutdown
assert gateway.disconnect.call_count == 2
gateway.shutdown.assert_called_once()
class TestReverseGatewayArgs:
"""Test cases for argument parsing and run function"""
def test_parse_args_defaults(self):
"""Test parse_args with default values"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = ['reverse-gateway']
try:
args = parse_args()
assert args.websocket_uri is None
assert args.max_workers == 10
assert args.pulsar_host is None
assert args.pulsar_api_key is None
assert args.pulsar_listener is None
finally:
sys.argv = original_argv
def test_parse_args_custom_values(self):
"""Test parse_args with custom values"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = [
'reverse-gateway',
'--websocket-uri', 'ws://custom:8080/ws',
'--max-workers', '20',
'--pulsar-host', 'pulsar://custom:6650',
'--pulsar-api-key', 'test-key',
'--pulsar-listener', 'test-listener'
]
try:
args = parse_args()
assert args.websocket_uri == 'ws://custom:8080/ws'
assert args.max_workers == 20
assert args.pulsar_host == 'pulsar://custom:6650'
assert args.pulsar_api_key == 'test-key'
assert args.pulsar_listener == 'test-listener'
finally:
sys.argv = original_argv
@patch('trustgraph.rev_gateway.service.ReverseGateway')
@patch('asyncio.run')
def test_run_function(self, mock_asyncio_run, mock_gateway_class):
"""Test run function"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = ['reverse-gateway', '--max-workers', '15']
try:
mock_gateway_instance = MagicMock()
mock_gateway_instance.url = "ws://localhost:7650/out"
mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650"
mock_gateway_class.return_value = mock_gateway_instance
run()
mock_gateway_class.assert_called_once_with(
websocket_uri=None,
max_workers=15,
pulsar_host=None,
pulsar_api_key=None,
pulsar_listener=None
)
mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run())
finally:
sys.argv = original_argv