Update rev-gateway for IAM integration (#940)

service.py:
- Constructor takes **config (same pattern as api-gateway) instead
  of individual args
- Creates IamAuth and calls await self.auth.start() before the
  message loop
- Passes auth to both ConfigReceiver and MessageDispatcher
- Uses add_pubsub_args / add_logging_args instead of hand-rolled
  Pulsar args
- Passes timeout through

dispatcher.py:
- Accepts auth and timeout parameters
- Passes both to DispatcherManager — fixes the missing auth argument
  that would have crashed on startup

The remote end's requests now go through the same IAM authentication
path as api-gateway. Token validation, workspace resolution, and
permissions all work identically regardless of which direction
initiated the connection.

Fixed tests — the test now passes auth and timeout to MessageDispatcher
and verifies they're forwarded to DispatcherManager.

Update rev gateway dispatcher to align with IAM.  A "token" parameter
must be passed with each message.

Fix websocket relay to align with rev-gateway changes, conforms to
the api-gateway protocol.
This commit is contained in:
cybermaggedon 2026-05-19 21:45:43 +01:00 committed by GitHub
parent 4e3bd85abc
commit e57f4669e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 914 additions and 865 deletions

View file

@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch, ANY
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher
class TestWebSocketResponder:
"""Test cases for WebSocketResponder class"""
def test_websocket_responder_initialization(self):
"""Test WebSocketResponder initialization"""
responder = WebSocketResponder()
assert responder.response is None
assert responder.completed is False
@pytest.mark.asyncio
async def test_websocket_responder_send_method(self):
"""Test WebSocketResponder send method"""
responder = WebSocketResponder()
test_response = {"data": "test response"}
# Call send method
await responder.send(test_response)
# Verify response was stored
assert responder.response == test_response
@pytest.mark.asyncio
async def test_websocket_responder_call_method(self):
"""Test WebSocketResponder __call__ method"""
responder = WebSocketResponder()
test_response = {"result": "success"}
test_completed = True
# Call the responder
await responder(test_response, test_completed)
# Verify response and completed status were set
assert responder.response == test_response
assert responder.completed == test_completed
@pytest.mark.asyncio
async def test_websocket_responder_call_method_with_false_completion(self):
"""Test WebSocketResponder __call__ method with incomplete response"""
responder = WebSocketResponder()
test_response = {"partial": "data"}
test_completed = False
# Call the responder
await responder(test_response, test_completed)
# Verify response was set and completed is True (since send() always sets completed=True)
assert responder.response == test_response
assert responder.completed is True
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
class TestMessageDispatcher:
"""Test cases for MessageDispatcher class"""
def test_message_dispatcher_initialization_with_defaults(self):
"""Test MessageDispatcher initialization with default parameters"""
dispatcher = MessageDispatcher()
assert dispatcher.max_workers == 10
assert dispatcher.semaphore._value == 10
assert dispatcher.active_tasks == set()
assert dispatcher.backend is None
assert dispatcher.auth is None
assert dispatcher.dispatcher_manager is None
assert len(dispatcher.service_mapping) > 0
def test_message_dispatcher_initialization_with_custom_workers(self):
"""Test MessageDispatcher initialization with custom max_workers"""
dispatcher = MessageDispatcher(max_workers=5)
assert dispatcher.max_workers == 5
assert dispatcher.semaphore._value == 5
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
"""Test MessageDispatcher initialization with pulsar_client and config_receiver"""
def test_message_dispatcher_initialization_with_backend(
self, mock_dispatcher_manager,
):
mock_backend = MagicMock()
mock_config_receiver = MagicMock()
mock_auth = MagicMock()
mock_dispatcher_instance = MagicMock()
mock_dispatcher_manager.return_value = mock_dispatcher_instance
dispatcher = MessageDispatcher(
max_workers=8,
config_receiver=mock_config_receiver,
backend=mock_backend
backend=mock_backend,
auth=mock_auth,
timeout=300,
)
assert dispatcher.max_workers == 8
assert dispatcher.backend == mock_backend
assert dispatcher.auth == mock_auth
assert dispatcher.dispatcher_manager == mock_dispatcher_instance
mock_dispatcher_manager.assert_called_once_with(
mock_backend, mock_config_receiver, prefix="rev-gateway"
mock_backend, mock_config_receiver,
auth=mock_auth, prefix="rev-gateway", timeout=300,
)
def test_message_dispatcher_service_mapping(self):
"""Test MessageDispatcher service mapping contains expected services"""
dispatcher = MessageDispatcher()
expected_services = [
"text-completion", "graph-rag", "agent", "embeddings",
"graph-embeddings", "triples", "document-load", "text-load",
"flow", "knowledge", "config", "librarian", "document-rag"
"flow", "knowledge", "config", "librarian", "document-rag",
]
for service in expected_services:
assert service in dispatcher.service_mapping
# Test specific mappings
assert dispatcher.service_mapping["text-completion"] == "text-completion"
assert dispatcher.service_mapping["document-load"] == "document"
assert dispatcher.service_mapping["text-load"] == "text-document"
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self):
"""Test MessageDispatcher handle_message without dispatcher manager"""
async def test_handle_message_without_dispatcher_manager(self):
dispatcher = MessageDispatcher()
test_message = {
"id": "test-123",
"service": "test-service",
"request": {"data": "test"}
}
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-123"
assert "error" in result["response"]
assert "DispatcherManager not available" in result["response"]["error"]
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="default")
)
sender = AsyncMock()
await dispatcher.handle_message(
{"id": "test-1", "service": "test", "request": {}},
sender,
)
sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-1"
assert sent["error"]["message"] == "DispatcherManager not available"
assert sent["error"]["type"] == "error"
assert sent["complete"] is True
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_with_exception(self):
"""Test MessageDispatcher handle_message with exception during processing"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
async def test_handle_message_auth_failure(self):
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-456",
"service": "text-completion",
"request": {"prompt": "test"}
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-456"
assert "error" in result["response"]
assert "Test error" in result["response"]["error"]
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
side_effect=Exception("auth failure")
)
dispatcher.dispatcher_manager = MagicMock()
sender = AsyncMock()
await dispatcher.handle_message(
{"id": "test-2", "token": "bad", "service": "test", "request": {}},
sender,
)
sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-2"
assert "auth failure" in sent["error"]["message"]
assert sent["complete"] is True
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_global_service(self):
"""Test MessageDispatcher handle_message with global service"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"result": "success"}
async def test_handle_message_global_service(self):
mock_dm = MagicMock()
mock_dm.invoke_global_service = AsyncMock()
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-789",
"service": "text-completion",
"request": {"prompt": "hello"}
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-789"
assert result["response"] == {"result": "success"}
mock_dispatcher_manager.invoke_global_service.assert_called_once()
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws1")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers',
{"text-completion": True},
):
await dispatcher.handle_message(
{
"id": "test-3",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hello"},
},
sender,
)
mock_dm.invoke_global_service.assert_called_once()
args, kwargs = mock_dm.invoke_global_service.call_args
assert args[0] == {"prompt": "hello"}
assert args[2] == "text-completion"
assert kwargs["workspace"] == "ws1"
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_flow_service(self):
"""Test MessageDispatcher handle_message with flow service"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"data": "flow_result"}
async def test_handle_message_flow_service(self):
mock_dm = MagicMock()
mock_dm.invoke_flow_service = AsyncMock()
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-flow-123",
"service": "document-rag",
"request": {"query": "test"},
"flow": "custom-flow"
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-flow-123"
assert result["response"] == {"data": "flow_result"}
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with(
{"query": "test"}, mock_responder, "custom-flow", "document-rag"
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws2")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
):
await dispatcher.handle_message(
{
"id": "test-4",
"token": "tg_key",
"service": "document-rag",
"request": {"query": "test"},
"flow": "my-flow",
},
sender,
)
mock_dm.invoke_flow_service.assert_called_once_with(
{"query": "test"}, ANY, "ws2", "my-flow", "document-rag",
)
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_incomplete_response(self):
"""Test MessageDispatcher handle_message with incomplete response"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = False
mock_responder.response = None
async def test_handle_message_responder_sends_frames(self):
mock_dm = MagicMock()
async def fake_invoke(data, responder, svc, workspace=None):
await responder({"partial": 1}, False)
await responder({"partial": 2}, True)
mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke)
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-incomplete",
"service": "agent",
"request": {"input": "test"}
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws1")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers',
{"text-completion": True},
):
await dispatcher.handle_message(
{
"id": "test-5",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hi"},
},
sender,
)
assert sender.call_count == 2
first = sender.call_args_list[0][0][0]
second = sender.call_args_list[1][0][0]
assert first == {
"id": "test-5", "response": {"partial": 1}, "complete": False,
}
assert second == {
"id": "test-5", "response": {"partial": 2}, "complete": True,
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-incomplete"
assert result["response"] == {"error": "No response received"}
@pytest.mark.asyncio
async def test_message_dispatcher_shutdown(self):
"""Test MessageDispatcher shutdown method"""
import asyncio
async def test_handle_message_workspace_from_identity(self):
mock_dm = MagicMock()
mock_dm.invoke_flow_service = AsyncMock()
dispatcher = MessageDispatcher()
# Create actual async tasks
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="derived-ws")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
):
await dispatcher.handle_message(
{
"id": "test-6",
"token": "tg_key",
"service": "agent",
"request": {"question": "test"},
"flow": "default",
},
sender,
)
args = mock_dm.invoke_flow_service.call_args[0]
assert args[2] == "derived-ws"
@pytest.mark.asyncio
async def test_shutdown(self):
dispatcher = MessageDispatcher()
async def dummy_task():
await asyncio.sleep(0.01)
return "done"
task1 = asyncio.create_task(dummy_task())
task2 = asyncio.create_task(dummy_task())
dispatcher.active_tasks = {task1, task2}
# Call shutdown
await dispatcher.shutdown()
# Verify tasks were completed
assert task1.done()
assert task2.done()
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
@pytest.mark.asyncio
async def test_message_dispatcher_shutdown_with_no_tasks(self):
"""Test MessageDispatcher shutdown with no active tasks"""
async def test_shutdown_with_no_tasks(self):
dispatcher = MessageDispatcher()
# Call shutdown with no active tasks
await dispatcher.shutdown()
# Should complete without error
assert dispatcher.active_tasks == set()
assert dispatcher.active_tasks == set()