From e57f4669e18b23d3c1fad93722757b1eee7d9fca Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 19 May 2026 21:45:43 +0100 Subject: [PATCH] Update rev-gateway for IAM integration (#940) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit service.py: - Constructor takes **config (same pattern as api-gateway) instead of individual args - Creates IamAuth and calls await self.auth.start() before the message loop - Passes auth to both ConfigReceiver and MessageDispatcher - Uses add_pubsub_args / add_logging_args instead of hand-rolled Pulsar args - Passes timeout through dispatcher.py: - Accepts auth and timeout parameters - Passes both to DispatcherManager — fixes the missing auth argument that would have crashed on startup The remote end's requests now go through the same IAM authentication path as api-gateway. Token validation, workspace resolution, and permissions all work identically regardless of which direction initiated the connection. Fixed tests — the test now passes auth and timeout to MessageDispatcher and verifies they're forwarded to DispatcherManager. Update rev gateway dispatcher to align with IAM. A "token" parameter must be passed with each message. Fix websocket relay to align with rev-gateway changes, conforms to the api-gateway protocol. --- dev-tools/tests/relay/websocket_relay.py | 332 +++++---- .../test_dispatcher_semaphore.py | 25 +- tests/unit/test_rev_gateway/__init__.py | 0 .../unit/test_rev_gateway/test_dispatcher.py | 394 +++++------ .../test_rev_gateway_service.py | 661 ++++++++---------- .../trustgraph/rev_gateway/dispatcher.py | 162 +++-- .../trustgraph/rev_gateway/service.py | 205 +++--- 7 files changed, 914 insertions(+), 865 deletions(-) create mode 100644 tests/unit/test_rev_gateway/__init__.py diff --git a/dev-tools/tests/relay/websocket_relay.py b/dev-tools/tests/relay/websocket_relay.py index d537f7da..6495ee49 100644 --- a/dev-tools/tests/relay/websocket_relay.py +++ b/dev-tools/tests/relay/websocket_relay.py @@ -3,208 +3,278 @@ WebSocket Relay Test Harness This script creates a relay server with two WebSocket endpoints: -- /in - for test clients to connect to -- /out - for reverse gateway to connect to +- /in - for test clients to connect to (speaks api-gateway protocol) +- /out - for reverse gateway to connect to (speaks rev-gateway protocol) -Messages are bidirectionally relayed between the two connections. +Clients on /in authenticate with a first-frame auth message: + {"type": "auth", "token": "..."} + +The relay stores the token and injects it into each subsequent message +before forwarding to /out. Responses from /out are forwarded back to +the originating /in connection unchanged. Usage: python websocket_relay.py [--port PORT] [--host HOST] """ import asyncio +import json import logging import argparse from aiohttp import web, WSMsgType -import weakref -from typing import Optional, Set +from typing import Dict, Optional -# Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger("websocket_relay") + +class InConnection: + def __init__(self, ws, conn_id): + self.ws = ws + self.conn_id = conn_id + self.token: Optional[str] = None + self.authenticated = False + + class WebSocketRelay: - """WebSocket relay that forwards messages between 'in' and 'out' connections""" - + def __init__(self): - self.in_connections: Set = weakref.WeakSet() - self.out_connections: Set = weakref.WeakSet() - + self.in_connections: Dict[str, InConnection] = {} + self.out_connections: set = set() + self._conn_counter = 0 + + def _next_conn_id(self): + self._conn_counter += 1 + return f"conn-{self._conn_counter}" + async def handle_in_connection(self, request): - """Handle incoming connections on /in endpoint""" ws = web.WebSocketResponse() await ws.prepare(request) - - self.in_connections.add(ws) - logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + + conn_id = self._next_conn_id() + conn = InConnection(ws, conn_id) + self.in_connections[conn_id] = conn + logger.info( + f"New 'in' connection {conn_id}. " + f"Total in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + try: async for msg in ws: if msg.type == WSMsgType.TEXT: - data = msg.data - logger.info(f"IN → OUT: {data}") - await self._forward_to_out(data) - elif msg.type == WSMsgType.BINARY: - data = msg.data - logger.info(f"IN → OUT: {len(data)} bytes (binary)") - await self._forward_to_out(data, binary=True) + await self._handle_in_message(conn, msg.data) elif msg.type == WSMsgType.ERROR: - logger.error(f"WebSocket error on 'in' connection: {ws.exception()}") + logger.error( + f"WebSocket error on 'in' connection " + f"{conn_id}: {ws.exception()}" + ) break else: break except Exception as e: - logger.error(f"Error in 'in' connection handler: {e}") + logger.error( + f"Error in 'in' connection {conn_id}: {e}" + ) finally: - logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + del self.in_connections[conn_id] + logger.info( + f"'in' connection {conn_id} closed. " + f"Remaining in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + return ws - - async def handle_out_connection(self, request): - """Handle outgoing connections on /out endpoint""" - ws = web.WebSocketResponse() - await ws.prepare(request) - - self.out_connections.add(ws) - logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + + async def _handle_in_message(self, conn, data): try: - async for msg in ws: - if msg.type == WSMsgType.TEXT: - data = msg.data - logger.info(f"OUT → IN: {data}") - await self._forward_to_in(data) - elif msg.type == WSMsgType.BINARY: - data = msg.data - logger.info(f"OUT → IN: {len(data)} bytes (binary)") - await self._forward_to_in(data, binary=True) - elif msg.type == WSMsgType.ERROR: - logger.error(f"WebSocket error on 'out' connection: {ws.exception()}") - break - else: - break - except Exception as e: - logger.error(f"Error in 'out' connection handler: {e}") - finally: - logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") - - return ws - - async def _forward_to_out(self, data, binary=False): - """Forward message from 'in' to all 'out' connections""" - if not self.out_connections: - logger.warning("No 'out' connections available to forward message") + message = json.loads(data) + except json.JSONDecodeError: + logger.warning( + f"{conn.conn_id}: received non-JSON message" + ) return - - closed_connections = [] + + if isinstance(message, dict) and message.get("type") == "auth": + conn.token = message.get("token", "") + conn.authenticated = True + logger.info(f"{conn.conn_id}: authenticated") + await conn.ws.send_json({ + "type": "auth-ok", + "workspace": "relayed", + }) + return + + if not conn.authenticated: + await conn.ws.send_json({ + "error": { + "message": "auth required", + "type": "auth-required", + }, + "complete": True, + }) + return + + message["token"] = conn.token + message["_relay_conn"] = conn.conn_id + + forwarded = json.dumps(message) + logger.info(f"IN {conn.conn_id} → OUT: {forwarded}") + await self._forward_to_out(forwarded) + + async def handle_out_connection(self, request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + self.out_connections.add(ws) + logger.info( + f"New 'out' connection. " + f"Total in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + await self._handle_out_message(msg.data) + elif msg.type == WSMsgType.ERROR: + logger.error( + f"WebSocket error on 'out' connection: " + f"{ws.exception()}" + ) + break + else: + break + except Exception as e: + logger.error(f"Error in 'out' connection: {e}") + finally: + self.out_connections.discard(ws) + logger.info( + f"'out' connection closed. " + f"Remaining in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + + return ws + + async def _handle_out_message(self, data): + try: + message = json.loads(data) + except json.JSONDecodeError: + logger.warning("OUT: received non-JSON message") + return + + conn_id = message.pop("_relay_conn", None) + + forwarded = json.dumps(message) + logger.info(f"OUT → IN {conn_id or 'broadcast'}: {forwarded}") + + if conn_id and conn_id in self.in_connections: + conn = self.in_connections[conn_id] + try: + if not conn.ws.closed: + await conn.ws.send_str(forwarded) + except Exception as e: + logger.error( + f"Error forwarding to 'in' {conn_id}: {e}" + ) + else: + await self._broadcast_to_in(forwarded) + + async def _broadcast_to_in(self, data): + closed = [] + for conn_id, conn in list(self.in_connections.items()): + try: + if conn.ws.closed: + closed.append(conn_id) + continue + await conn.ws.send_str(data) + except Exception as e: + logger.error( + f"Error broadcasting to 'in' {conn_id}: {e}" + ) + closed.append(conn_id) + for conn_id in closed: + self.in_connections.pop(conn_id, None) + + async def _forward_to_out(self, data): + closed = [] for ws in list(self.out_connections): try: if ws.closed: - closed_connections.append(ws) + closed.append(ws) continue - - if binary: - await ws.send_bytes(data) - else: - await ws.send_str(data) + await ws.send_str(data) except Exception as e: - logger.error(f"Error forwarding to 'out' connection: {e}") - closed_connections.append(ws) - - # Clean up closed connections - for ws in closed_connections: - if ws in self.out_connections: - self.out_connections.discard(ws) - - async def _forward_to_in(self, data, binary=False): - """Forward message from 'out' to all 'in' connections""" - if not self.in_connections: - logger.warning("No 'in' connections available to forward message") - return - - closed_connections = [] - for ws in list(self.in_connections): - try: - if ws.closed: - closed_connections.append(ws) - continue - - if binary: - await ws.send_bytes(data) - else: - await ws.send_str(data) - except Exception as e: - logger.error(f"Error forwarding to 'in' connection: {e}") - closed_connections.append(ws) - - # Clean up closed connections - for ws in closed_connections: - if ws in self.in_connections: - self.in_connections.discard(ws) + logger.error(f"Error forwarding to 'out': {e}") + closed.append(ws) + for ws in closed: + self.out_connections.discard(ws) + async def create_app(relay): - """Create the web application with routes""" app = web.Application() - - # Add routes - app.router.add_get('/in', relay.handle_in_connection) + + app.router.add_get('/in/api/v1/socket', relay.handle_in_connection) app.router.add_get('/out', relay.handle_out_connection) - - # Add a simple status endpoint + async def status(request): - status_info = { + return web.json_response({ 'in_connections': len(relay.in_connections), 'out_connections': len(relay.out_connections), - 'status': 'running' - } - return web.json_response(status_info) - + 'status': 'running', + }) + app.router.add_get('/status', status) - app.router.add_get('/', status) # Root also shows status - + app.router.add_get('/', status) + return app + def main(): parser = argparse.ArgumentParser( description="WebSocket Relay Test Harness" ) parser.add_argument( - '--host', + '--host', default='localhost', - help='Host to bind to (default: localhost)' + help='Host to bind to (default: localhost)', ) parser.add_argument( - '--port', - type=int, + '--port', + type=int, default=8080, - help='Port to bind to (default: 8080)' + help='Port to bind to (default: 8080)', ) parser.add_argument( '--verbose', '-v', action='store_true', - help='Enable verbose logging' + help='Enable verbose logging', ) - + args = parser.parse_args() - + if args.verbose: logging.getLogger().setLevel(logging.DEBUG) - + relay = WebSocketRelay() - + print(f"Starting WebSocket Relay on {args.host}:{args.port}") - print(f" 'in' endpoint: ws://{args.host}:{args.port}/in") + print(f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket") print(f" 'out' endpoint: ws://{args.host}:{args.port}/out") print(f" Status: http://{args.host}:{args.port}/status") print() - print("Usage:") - print(f" Test client connects to: ws://{args.host}:{args.port}/in") - print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out") - + print("Client protocol (same as api-gateway):") + print(' 1. Connect to /in/api/v1/socket') + print(' 2. Send: {"type": "auth", "token": "tg_..."}') + print(' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}') + print(' 4. Send requests as normal') + web.run_app(create_app(relay), host=args.host, port=args.port) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/unit/test_concurrency/test_dispatcher_semaphore.py b/tests/unit/test_concurrency/test_dispatcher_semaphore.py index 6a1ae8ab..a5374678 100644 --- a/tests/unit/test_concurrency/test_dispatcher_semaphore.py +++ b/tests/unit/test_concurrency/test_dispatcher_semaphore.py @@ -25,16 +25,17 @@ class TestSemaphoreEnforcement: max_concurrent = 0 processing_event = asyncio.Event() - async def slow_process(message): + async def slow_process(message, sender): nonlocal concurrent_count, max_concurrent concurrent_count += 1 max_concurrent = max(max_concurrent, concurrent_count) await asyncio.sleep(0.05) concurrent_count -= 1 - return {"id": message.get("id"), "response": {"ok": True}} dispatcher._process_message = slow_process + sender = AsyncMock() + # Launch more tasks than max_workers messages = [ {"id": f"msg-{i}", "service": "test", "request": {}} @@ -42,7 +43,7 @@ class TestSemaphoreEnforcement: ] tasks = [ - asyncio.create_task(dispatcher.handle_message(m)) + asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages ] @@ -66,17 +67,17 @@ class TestSemaphoreEnforcement: original_process = dispatcher._process_message - async def tracking_process(message): + async def tracking_process(message, sender): nonlocal task_was_tracked # During processing, our task should be in active_tasks if len(dispatcher.active_tasks) > 0: task_was_tracked = True - return {"id": message.get("id"), "response": {"ok": True}} dispatcher._process_message = tracking_process await dispatcher.handle_message( - {"id": "test", "service": "test", "request": {}} + {"id": "test", "service": "test", "request": {}}, + AsyncMock(), ) assert task_was_tracked @@ -88,7 +89,7 @@ class TestSemaphoreEnforcement: """Semaphore should be released even if processing raises.""" dispatcher = MessageDispatcher(max_workers=2) - async def failing_process(message): + async def failing_process(message, sender): raise RuntimeError("process failed") dispatcher._process_message = failing_process @@ -96,7 +97,8 @@ class TestSemaphoreEnforcement: # Should not deadlock — semaphore must be released on error with pytest.raises(RuntimeError): await dispatcher.handle_message( - {"id": "test", "service": "test", "request": {}} + {"id": "test", "service": "test", "request": {}}, + AsyncMock(), ) # Semaphore should be back at max @@ -109,17 +111,18 @@ class TestSemaphoreEnforcement: order = [] - async def ordered_process(message): + async def ordered_process(message, sender): msg_id = message["id"] order.append(f"start-{msg_id}") await asyncio.sleep(0.02) order.append(f"end-{msg_id}") - return {"id": msg_id, "response": {"ok": True}} dispatcher._process_message = ordered_process + sender = AsyncMock() + messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)] - tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages] + tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages] await asyncio.gather(*tasks) # With semaphore=1, each message should complete before next starts diff --git a/tests/unit/test_rev_gateway/__init__.py b/tests/unit/test_rev_gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_rev_gateway/test_dispatcher.py b/tests/unit/test_rev_gateway/test_dispatcher.py index 2a9c8df0..6df786cf 100644 --- a/tests/unit/test_rev_gateway/test_dispatcher.py +++ b/tests/unit/test_rev_gateway/test_dispatcher.py @@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher """ import pytest -from unittest.mock import MagicMock, AsyncMock, patch +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch, ANY -from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher - - -class TestWebSocketResponder: - """Test cases for WebSocketResponder class""" - - def test_websocket_responder_initialization(self): - """Test WebSocketResponder initialization""" - responder = WebSocketResponder() - - assert responder.response is None - assert responder.completed is False - - @pytest.mark.asyncio - async def test_websocket_responder_send_method(self): - """Test WebSocketResponder send method""" - responder = WebSocketResponder() - - test_response = {"data": "test response"} - - # Call send method - await responder.send(test_response) - - # Verify response was stored - assert responder.response == test_response - - @pytest.mark.asyncio - async def test_websocket_responder_call_method(self): - """Test WebSocketResponder __call__ method""" - responder = WebSocketResponder() - - test_response = {"result": "success"} - test_completed = True - - # Call the responder - await responder(test_response, test_completed) - - # Verify response and completed status were set - assert responder.response == test_response - assert responder.completed == test_completed - - @pytest.mark.asyncio - async def test_websocket_responder_call_method_with_false_completion(self): - """Test WebSocketResponder __call__ method with incomplete response""" - responder = WebSocketResponder() - - test_response = {"partial": "data"} - test_completed = False - - # Call the responder - await responder(test_response, test_completed) - - # Verify response was set and completed is True (since send() always sets completed=True) - assert responder.response == test_response - assert responder.completed is True +from trustgraph.rev_gateway.dispatcher import MessageDispatcher class TestMessageDispatcher: """Test cases for MessageDispatcher class""" def test_message_dispatcher_initialization_with_defaults(self): - """Test MessageDispatcher initialization with default parameters""" dispatcher = MessageDispatcher() - + assert dispatcher.max_workers == 10 assert dispatcher.semaphore._value == 10 assert dispatcher.active_tasks == set() assert dispatcher.backend is None + assert dispatcher.auth is None assert dispatcher.dispatcher_manager is None assert len(dispatcher.service_mapping) > 0 def test_message_dispatcher_initialization_with_custom_workers(self): - """Test MessageDispatcher initialization with custom max_workers""" dispatcher = MessageDispatcher(max_workers=5) - + assert dispatcher.max_workers == 5 assert dispatcher.semaphore._value == 5 @patch('trustgraph.rev_gateway.dispatcher.DispatcherManager') - def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager): - """Test MessageDispatcher initialization with pulsar_client and config_receiver""" + def test_message_dispatcher_initialization_with_backend( + self, mock_dispatcher_manager, + ): mock_backend = MagicMock() mock_config_receiver = MagicMock() + mock_auth = MagicMock() mock_dispatcher_instance = MagicMock() mock_dispatcher_manager.return_value = mock_dispatcher_instance - + dispatcher = MessageDispatcher( max_workers=8, config_receiver=mock_config_receiver, - backend=mock_backend + backend=mock_backend, + auth=mock_auth, + timeout=300, ) - + assert dispatcher.max_workers == 8 assert dispatcher.backend == mock_backend + assert dispatcher.auth == mock_auth assert dispatcher.dispatcher_manager == mock_dispatcher_instance mock_dispatcher_manager.assert_called_once_with( - mock_backend, mock_config_receiver, prefix="rev-gateway" + mock_backend, mock_config_receiver, + auth=mock_auth, prefix="rev-gateway", timeout=300, ) def test_message_dispatcher_service_mapping(self): - """Test MessageDispatcher service mapping contains expected services""" dispatcher = MessageDispatcher() - + expected_services = [ "text-completion", "graph-rag", "agent", "embeddings", "graph-embeddings", "triples", "document-load", "text-load", - "flow", "knowledge", "config", "librarian", "document-rag" + "flow", "knowledge", "config", "librarian", "document-rag", ] - + for service in expected_services: assert service in dispatcher.service_mapping - - # Test specific mappings - assert dispatcher.service_mapping["text-completion"] == "text-completion" + assert dispatcher.service_mapping["document-load"] == "document" assert dispatcher.service_mapping["text-load"] == "text-document" @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_without_dispatcher_manager(self): - """Test MessageDispatcher handle_message without dispatcher manager""" + async def test_handle_message_without_dispatcher_manager(self): dispatcher = MessageDispatcher() - - test_message = { - "id": "test-123", - "service": "test-service", - "request": {"data": "test"} - } - - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-123" - assert "error" in result["response"] - assert "DispatcherManager not available" in result["response"]["error"] + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="default") + ) + + sender = AsyncMock() + + await dispatcher.handle_message( + {"id": "test-1", "service": "test", "request": {}}, + sender, + ) + + sender.assert_called_once() + sent = sender.call_args[0][0] + assert sent["id"] == "test-1" + assert sent["error"]["message"] == "DispatcherManager not available" + assert sent["error"]["type"] == "error" + assert sent["complete"] is True @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_with_exception(self): - """Test MessageDispatcher handle_message with exception during processing""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error")) - + async def test_handle_message_auth_failure(self): dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-456", - "service": "text-completion", - "request": {"prompt": "test"} - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-456" - assert "error" in result["response"] - assert "Test error" in result["response"]["error"] + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + side_effect=Exception("auth failure") + ) + dispatcher.dispatcher_manager = MagicMock() + + sender = AsyncMock() + + await dispatcher.handle_message( + {"id": "test-2", "token": "bad", "service": "test", "request": {}}, + sender, + ) + + sender.assert_called_once() + sent = sender.call_args[0][0] + assert sent["id"] == "test-2" + assert "auth failure" in sent["error"]["message"] + assert sent["complete"] is True @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_global_service(self): - """Test MessageDispatcher handle_message with global service""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_global_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = True - mock_responder.response = {"result": "success"} - + async def test_handle_message_global_service(self): + mock_dm = MagicMock() + mock_dm.invoke_global_service = AsyncMock() + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-789", - "service": "text-completion", - "request": {"prompt": "hello"} - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-789" - assert result["response"] == {"result": "success"} - mock_dispatcher_manager.invoke_global_service.assert_called_once() + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws1") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', + {"text-completion": True}, + ): + await dispatcher.handle_message( + { + "id": "test-3", + "token": "tg_key", + "service": "text-completion", + "request": {"prompt": "hello"}, + }, + sender, + ) + + mock_dm.invoke_global_service.assert_called_once() + args, kwargs = mock_dm.invoke_global_service.call_args + assert args[0] == {"prompt": "hello"} + assert args[2] == "text-completion" + assert kwargs["workspace"] == "ws1" @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_flow_service(self): - """Test MessageDispatcher handle_message with flow service""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_flow_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = True - mock_responder.response = {"data": "flow_result"} - + async def test_handle_message_flow_service(self): + mock_dm = MagicMock() + mock_dm.invoke_flow_service = AsyncMock() + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-flow-123", - "service": "document-rag", - "request": {"query": "test"}, - "flow": "custom-flow" - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-flow-123" - assert result["response"] == {"data": "flow_result"} - mock_dispatcher_manager.invoke_flow_service.assert_called_once_with( - {"query": "test"}, mock_responder, "custom-flow", "document-rag" + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws2") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', {}, + ): + await dispatcher.handle_message( + { + "id": "test-4", + "token": "tg_key", + "service": "document-rag", + "request": {"query": "test"}, + "flow": "my-flow", + }, + sender, + ) + + mock_dm.invoke_flow_service.assert_called_once_with( + {"query": "test"}, ANY, "ws2", "my-flow", "document-rag", ) @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_incomplete_response(self): - """Test MessageDispatcher handle_message with incomplete response""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_flow_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = False - mock_responder.response = None - + async def test_handle_message_responder_sends_frames(self): + mock_dm = MagicMock() + + async def fake_invoke(data, responder, svc, workspace=None): + await responder({"partial": 1}, False) + await responder({"partial": 2}, True) + + mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke) + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-incomplete", - "service": "agent", - "request": {"input": "test"} + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws1") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', + {"text-completion": True}, + ): + await dispatcher.handle_message( + { + "id": "test-5", + "token": "tg_key", + "service": "text-completion", + "request": {"prompt": "hi"}, + }, + sender, + ) + + assert sender.call_count == 2 + first = sender.call_args_list[0][0][0] + second = sender.call_args_list[1][0][0] + + assert first == { + "id": "test-5", "response": {"partial": 1}, "complete": False, + } + assert second == { + "id": "test-5", "response": {"partial": 2}, "complete": True, } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-incomplete" - assert result["response"] == {"error": "No response received"} @pytest.mark.asyncio - async def test_message_dispatcher_shutdown(self): - """Test MessageDispatcher shutdown method""" - import asyncio - + async def test_handle_message_workspace_from_identity(self): + mock_dm = MagicMock() + mock_dm.invoke_flow_service = AsyncMock() + dispatcher = MessageDispatcher() - - # Create actual async tasks + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="derived-ws") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', {}, + ): + await dispatcher.handle_message( + { + "id": "test-6", + "token": "tg_key", + "service": "agent", + "request": {"question": "test"}, + "flow": "default", + }, + sender, + ) + + args = mock_dm.invoke_flow_service.call_args[0] + assert args[2] == "derived-ws" + + @pytest.mark.asyncio + async def test_shutdown(self): + dispatcher = MessageDispatcher() + async def dummy_task(): await asyncio.sleep(0.01) - return "done" - + task1 = asyncio.create_task(dummy_task()) task2 = asyncio.create_task(dummy_task()) dispatcher.active_tasks = {task1, task2} - - # Call shutdown + await dispatcher.shutdown() - - # Verify tasks were completed + assert task1.done() assert task2.done() - assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed @pytest.mark.asyncio - async def test_message_dispatcher_shutdown_with_no_tasks(self): - """Test MessageDispatcher shutdown with no active tasks""" + async def test_shutdown_with_no_tasks(self): dispatcher = MessageDispatcher() - - # Call shutdown with no active tasks + await dispatcher.shutdown() - - # Should complete without error - assert dispatcher.active_tasks == set() \ No newline at end of file + + assert dispatcher.active_tasks == set() diff --git a/tests/unit/test_rev_gateway/test_rev_gateway_service.py b/tests/unit/test_rev_gateway/test_rev_gateway_service.py index 23aff18e..e2d1045d 100644 --- a/tests/unit/test_rev_gateway/test_rev_gateway_service.py +++ b/tests/unit/test_rev_gateway/test_rev_gateway_service.py @@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock from aiohttp import WSMsgType, ClientWebSocketResponse import json -from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run +from trustgraph.rev_gateway.service import ReverseGateway, run + + +MOCK_PATCHES = [ + 'trustgraph.rev_gateway.service.IamAuth', + 'trustgraph.rev_gateway.service.ConfigReceiver', + 'trustgraph.rev_gateway.service.MessageDispatcher', + 'trustgraph.rev_gateway.service.get_pubsub', +] + + +def make_gateway(**overrides): + config = {"websocket_uri": "ws://localhost:7650/out"} + config.update(overrides) + return ReverseGateway(**config) class TestReverseGateway: """Test cases for ReverseGateway class""" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with default parameters""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_defaults( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + assert gateway.websocket_uri == "ws://localhost:7650/out" assert gateway.host == "localhost" assert gateway.port == 7650 @@ -33,25 +49,22 @@ class TestReverseGateway: assert gateway.max_workers == 10 assert gateway.running is False assert gateway.reconnect_delay == 3.0 - assert gateway.pulsar_host == "pulsar://pulsar:6650" - assert gateway.pulsar_api_key is None - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with custom parameters""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway( + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_custom_params( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway( websocket_uri="wss://example.com:8080/websocket", max_workers=20, - pulsar_host="pulsar://custom:6650", - pulsar_api_key="test-key", - pulsar_listener="test-listener" ) - + assert gateway.websocket_uri == "wss://example.com:8080/websocket" assert gateway.host == "example.com" assert gateway.port == 8080 @@ -59,340 +72,360 @@ class TestReverseGateway: assert gateway.path == "/websocket" assert gateway.url == "wss://example.com:8080/websocket" assert gateway.max_workers == 20 - assert gateway.pulsar_host == "pulsar://custom:6650" - assert gateway.pulsar_api_key == "test-key" - assert gateway.pulsar_listener == "test-listener" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with WebSocket URI missing path""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway(websocket_uri="ws://example.com") - + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_with_missing_path( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway(websocket_uri="ws://example.com") + assert gateway.path == "/ws" assert gateway.url == "ws://example.com/ws" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with invalid WebSocket scheme""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_invalid_scheme( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"): - ReverseGateway(websocket_uri="http://example.com") + make_gateway(websocket_uri="http://example.com") - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with missing hostname""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_missing_hostname( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): with pytest.raises(ValueError, match="WebSocket URI must include hostname"): - ReverseGateway(websocket_uri="ws://") + make_gateway(websocket_uri="ws://") - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway creates backend with authentication""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_iam_auth_created( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): mock_backend = MagicMock() mock_get_pubsub.return_value = mock_backend - gateway = ReverseGateway( - pulsar_api_key="test-key", - pulsar_listener="test-listener" + gateway = make_gateway(id="test-rev-gw") + + mock_iam_auth.assert_called_once_with( + backend=mock_backend, + id="test-rev-gw", ) - # Verify get_pubsub was called with the correct parameters - mock_get_pubsub.assert_called_once_with( - pulsar_host="pulsar://pulsar:6650", - pulsar_api_key="test-key", - pulsar_listener="test-listener" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_config_receiver_gets_auth( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend + mock_auth_instance = MagicMock() + mock_iam_auth.return_value = mock_auth_instance + + gateway = make_gateway() + + mock_config_receiver.assert_called_once_with( + mock_backend, auth=mock_auth_instance, ) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @patch('trustgraph.rev_gateway.service.ClientSession') @pytest.mark.asyncio - async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway successful connection""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_connect_success( + self, mock_session_class, mock_get_pubsub, + mock_dispatcher, mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + mock_session = AsyncMock() mock_ws = AsyncMock() mock_session.ws_connect.return_value = mock_ws mock_session_class.return_value = mock_session - - gateway = ReverseGateway() - + + gateway = make_gateway() + result = await gateway.connect() - + assert result is True assert gateway.session == mock_session assert gateway.ws == mock_ws mock_session.ws_connect.assert_called_once_with(gateway.url) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @patch('trustgraph.rev_gateway.service.ClientSession') @pytest.mark.asyncio - async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway connection failure""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_connect_failure( + self, mock_session_class, mock_get_pubsub, + mock_dispatcher, mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + mock_session = AsyncMock() mock_session.ws_connect.side_effect = Exception("Connection failed") mock_session_class.return_value = mock_session - - gateway = ReverseGateway() - + + gateway = make_gateway() + result = await gateway.connect() - + assert result is False - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway disconnect""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock websocket and session + async def test_reverse_gateway_disconnect( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = False mock_session = AsyncMock() mock_session.closed = False - + gateway.ws = mock_ws gateway.session = mock_session - + await gateway.disconnect() - + mock_ws.close.assert_called_once() mock_session.close.assert_called_once() assert gateway.ws is None assert gateway.session is None - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway send message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock websocket + async def test_reverse_gateway_send_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - + test_message = {"id": "test", "data": "hello"} - + await gateway.send_message(test_message) - + mock_ws.send_str.assert_called_once_with(json.dumps(test_message)) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway send message with closed connection""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock closed websocket + async def test_reverse_gateway_send_message_closed_connection( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = True gateway.ws = mock_ws - + test_message = {"id": "test", "data": "hello"} - + await gateway.send_message(test_message) - - # Should not call send_str on closed connection + mock_ws.send_str.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway handle message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - mock_dispatcher_instance = AsyncMock() - mock_dispatcher_instance.handle_message.return_value = {"response": "success"} - mock_dispatcher.return_value = mock_dispatcher_instance - - gateway = ReverseGateway() - - # Mock send_message - gateway.send_message = AsyncMock() - - test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}' - - await gateway.handle_message(test_message) - - mock_dispatcher_instance.handle_message.assert_called_once_with({ - "id": "test", - "service": "test-service", - "request": {"data": "test"} - }) - gateway.send_message.assert_called_once_with({"response": "success"}) + async def test_reverse_gateway_handle_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + mock_dispatcher_instance = AsyncMock() + mock_dispatcher.return_value = mock_dispatcher_instance + + gateway = make_gateway() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - @pytest.mark.asyncio - async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway handle message with invalid JSON""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock send_message gateway.send_message = AsyncMock() - - test_message = 'invalid json' - - # Should not raise exception + + test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}' + await gateway.handle_message(test_message) - - # Should not call send_message due to error + + mock_dispatcher_instance.handle_message.assert_called_once_with( + { + "id": "test", + "service": "test-service", + "request": {"data": "test"}, + }, + gateway.send_message, + ) + + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + @pytest.mark.asyncio + async def test_reverse_gateway_handle_message_invalid_json( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + + gateway.send_message = AsyncMock() + + await gateway.handle_message('invalid json') + gateway.send_message.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with text message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_text_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.TEXT mock_msg.data = '{"test": "message"}' - - # Mock receive to return message once, then raise exception to stop loop + mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] - - # listen() catches exceptions and breaks, so no exception should be raised + await gateway.listen() - + gateway.handle_message.assert_called_once_with('{"test": "message"}') - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with binary message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_binary_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.BINARY mock_msg.data = b'{"test": "binary"}' - - # Mock receive to return message once, then raise exception to stop loop + mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] - - # listen() catches exceptions and breaks, so no exception should be raised + await gateway.listen() - + gateway.handle_message.assert_called_once_with('{"test": "binary"}') - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with close message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_close_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.CLOSE - - # Mock receive to return close message + mock_ws.receive.return_value = mock_msg - + await gateway.listen() - - # Should not call handle_message for close message + gateway.handle_message.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway shutdown""" + async def test_reverse_gateway_shutdown( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): mock_backend = MagicMock() mock_get_pubsub.return_value = mock_backend mock_dispatcher_instance = AsyncMock() mock_dispatcher.return_value = mock_dispatcher_instance - gateway = ReverseGateway() + gateway = make_gateway() gateway.running = True - # Mock disconnect gateway.disconnect = AsyncMock() await gateway.shutdown() @@ -402,46 +435,50 @@ class TestReverseGateway: gateway.disconnect.assert_called_once() mock_backend.close.assert_called_once() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway stop""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_stop( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - + gateway.stop() - + assert gateway.running is False class TestReverseGatewayRun: """Test cases for ReverseGateway run method""" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway run method with successful connect/listen cycle""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_run_successful_cycle( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + mock_auth_instance = AsyncMock() + mock_iam_auth.return_value = mock_auth_instance + mock_config_receiver_instance = AsyncMock() mock_config_receiver.return_value = mock_config_receiver_instance - - gateway = ReverseGateway() - - # Mock methods - gateway.connect = AsyncMock(return_value=True) + + gateway = make_gateway() + gateway.listen = AsyncMock() gateway.disconnect = AsyncMock() gateway.shutdown = AsyncMock() - - # Stop after one iteration + call_count = 0 async def mock_connect(): nonlocal call_count @@ -451,91 +488,13 @@ class TestReverseGatewayRun: else: gateway.running = False return False - + gateway.connect = mock_connect - + await gateway.run() - + + mock_auth_instance.start.assert_called_once() mock_config_receiver_instance.start.assert_called_once() gateway.listen.assert_called_once() - # disconnect is called twice: once in the main loop, once in shutdown assert gateway.disconnect.call_count == 2 gateway.shutdown.assert_called_once() - - -class TestReverseGatewayArgs: - """Test cases for argument parsing and run function""" - - def test_parse_args_defaults(self): - """Test parse_args with default values""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = ['reverse-gateway'] - - try: - args = parse_args() - - assert args.websocket_uri is None - assert args.max_workers == 10 - assert args.pulsar_host is None - assert args.pulsar_api_key is None - assert args.pulsar_listener is None - finally: - sys.argv = original_argv - - def test_parse_args_custom_values(self): - """Test parse_args with custom values""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = [ - 'reverse-gateway', - '--websocket-uri', 'ws://custom:8080/ws', - '--max-workers', '20', - '--pulsar-host', 'pulsar://custom:6650', - '--pulsar-api-key', 'test-key', - '--pulsar-listener', 'test-listener' - ] - - try: - args = parse_args() - - assert args.websocket_uri == 'ws://custom:8080/ws' - assert args.max_workers == 20 - assert args.pulsar_host == 'pulsar://custom:6650' - assert args.pulsar_api_key == 'test-key' - assert args.pulsar_listener == 'test-listener' - finally: - sys.argv = original_argv - - @patch('trustgraph.rev_gateway.service.ReverseGateway') - @patch('asyncio.run') - def test_run_function(self, mock_asyncio_run, mock_gateway_class): - """Test run function""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = ['reverse-gateway', '--max-workers', '15'] - - try: - mock_gateway_instance = MagicMock() - mock_gateway_instance.url = "ws://localhost:7650/out" - mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650" - mock_gateway_class.return_value = mock_gateway_instance - - run() - - mock_gateway_class.assert_called_once_with( - websocket_uri=None, - max_workers=15, - pulsar_host=None, - pulsar_api_key=None, - pulsar_listener=None - ) - mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run()) - finally: - sys.argv = original_argv \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py index 986558ec..5e71a0e5 100644 --- a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py +++ b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py @@ -1,130 +1,140 @@ import asyncio import logging import uuid -from typing import Dict, Any, Optional -from trustgraph.messaging import TranslatorRegistry +from typing import Dict, Any, Optional, Callable, Awaitable from ..gateway.dispatch.manager import DispatcherManager logger = logging.getLogger("dispatcher") -logger.setLevel(logging.INFO) -class WebSocketResponder: - """Simple responder that captures response for websocket return""" - def __init__(self): - self.response = None - self.completed = False - - async def send(self, data): - """Capture the response data""" - self.response = data - self.completed = True - - async def __call__(self, data, final=False): - """Make the responder callable for compatibility with requestor""" - await self.send(data) - if final: - self.completed = True + +class _TokenShim: + def __init__(self, token): + self.headers = ( + {"Authorization": f"Bearer {token}"} if token else {} + ) + class MessageDispatcher: - def __init__(self, max_workers: int = 10, config_receiver=None, backend=None): + def __init__(self, max_workers=10, config_receiver=None, backend=None, + auth=None, timeout=120): self.max_workers = max_workers self.semaphore = asyncio.Semaphore(max_workers) self.active_tasks = set() self.backend = backend + self.auth = auth - # Use DispatcherManager for flow and service management - if backend and config_receiver: - self.dispatcher_manager = DispatcherManager(backend, config_receiver, prefix="rev-gateway") + if backend and config_receiver and auth: + self.dispatcher_manager = DispatcherManager( + backend, config_receiver, + auth=auth, + prefix="rev-gateway", + timeout=timeout, + ) else: self.dispatcher_manager = None - logger.warning("No backend or config_receiver provided - using fallback mode") - - # Service name mapping from websocket protocol to translator registry + logger.warning( + "Missing backend, config_receiver, or auth " + "— using fallback mode" + ) + self.service_mapping = { "text-completion": "text-completion", - "graph-rag": "graph-rag", + "graph-rag": "graph-rag", "agent": "agent", "embeddings": "embeddings", "graph-embeddings": "graph-embeddings", "triples": "triples", "document-load": "document", - "text-load": "text-document", + "text-load": "text-document", "flow": "flow", "knowledge": "knowledge", "config": "config", "librarian": "librarian", - "document-rag": "document-rag" + "document-rag": "document-rag", } - - async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]: + + async def handle_message( + self, message: Dict[Any, Any], + sender: Callable[[dict], Awaitable[None]], + ): async with self.semaphore: - task = asyncio.create_task(self._process_message(message)) + task = asyncio.create_task( + self._process_message(message, sender) + ) self.active_tasks.add(task) - + try: - result = await task - return result + await task finally: self.active_tasks.discard(task) - - async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]: + + async def _authenticate(self, token): + if not self.auth: + raise RuntimeError("Auth not configured") + return await self.auth.authenticate(_TokenShim(token)) + + async def _process_message( + self, message: Dict[Any, Any], + sender: Callable[[dict], Awaitable[None]], + ): request_id = message.get('id', str(uuid.uuid4())) service = message.get('service') request_data = message.get('request', {}) - flow_id = message.get('flow', 'default') # Default flow - - logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}") - + token = message.get('token', '') + flow_id = message.get('flow', 'default') + + logger.info( + f"Processing message {request_id} for service " + f"{service} on flow {flow_id}" + ) + try: if not self.dispatcher_manager: - raise RuntimeError("DispatcherManager not available - backend and config_receiver required") - - # Use DispatcherManager for flow-based processing - responder = WebSocketResponder() - - # Map websocket service name to dispatcher service name + raise RuntimeError( + "DispatcherManager not available" + ) + + identity = await self._authenticate(token) + workspace = identity.workspace + + async def responder(resp, fin): + await sender({ + "id": request_id, + "response": resp, + "complete": fin, + }) + dispatcher_service = self.service_mapping.get(service, service) - - # Check if this is a global service or flow service + from ..gateway.dispatch.manager import global_dispatchers if dispatcher_service in global_dispatchers: - # Use global service dispatcher await self.dispatcher_manager.invoke_global_service( - request_data, responder, dispatcher_service + request_data, responder, dispatcher_service, + workspace=workspace, ) else: - # Use DispatcherManager to process the request through Pulsar queues await self.dispatcher_manager.invoke_flow_service( - request_data, responder, flow_id, dispatcher_service + request_data, responder, workspace, flow_id, + dispatcher_service, ) - - # Get the response from the responder - if responder.completed: - response_data = responder.response - else: - response_data = {'error': 'No response received'} - - response = { - 'id': request_id, - 'response': response_data - } - + except Exception as e: logger.error(f"Error processing message {request_id}: {e}") - response = { - 'id': request_id, - 'response': {'error': str(e)} - } - + await sender({ + "id": request_id, + "error": {"message": str(e), "type": "error"}, + "complete": True, + }) + logger.info(f"Completed processing message {request_id}") - return response - - + async def shutdown(self): if self.active_tasks: - logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete") + logger.info( + f"Waiting for {len(self.active_tasks)} active " + f"tasks to complete" + ) await asyncio.gather(*self.active_tasks, return_exceptions=True) - - # DispatcherManager handles its own cleanup + logger.info("Dispatcher shutdown complete") diff --git a/trustgraph-flow/trustgraph/rev_gateway/service.py b/trustgraph-flow/trustgraph/rev_gateway/service.py index cc905172..39180e2c 100644 --- a/trustgraph-flow/trustgraph/rev_gateway/service.py +++ b/trustgraph-flow/trustgraph/rev_gateway/service.py @@ -1,3 +1,9 @@ +""" +Reverse gateway. Initiates outbound WebSocket connections to a remote +relay and dispatches incoming requests through the same DispatcherManager +pipeline as api-gateway. +""" + import asyncio import argparse import logging @@ -9,82 +15,86 @@ from typing import Optional from urllib.parse import urlparse, urlunparse from .dispatcher import MessageDispatcher +from ..gateway.auth import IamAuth from ..gateway.config.receiver import ConfigReceiver -from ..base import get_pubsub +from ..base.pubsub import get_pubsub, add_pubsub_args +from ..base.logging import setup_logging, add_logging_args logger = logging.getLogger("rev_gateway") -logger.setLevel(logging.INFO) default_websocket = "ws://localhost:7650/out" +default_timeout = 600 class ReverseGateway: - - def __init__(self, websocket_uri: str = None, max_workers: int = 10, - pulsar_host: str = None, pulsar_api_key: str = None, - pulsar_listener: str = None): - # Set default WebSocket URI with environment variable support + + def __init__(self, **config): + websocket_uri = config.get("websocket_uri") if websocket_uri is None: websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket) - - # Parse and validate the WebSocket URI + parsed_uri = urlparse(websocket_uri) if parsed_uri.scheme not in ('ws', 'wss'): - raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}") + raise ValueError( + f"WebSocket URI must use ws:// or wss:// scheme, " + f"got: {parsed_uri.scheme}" + ) if not parsed_uri.netloc: - raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}") - - # Store parsed components for debugging/logging + raise ValueError( + f"WebSocket URI must include hostname, " + f"got: {websocket_uri}" + ) + self.websocket_uri = websocket_uri self.host = parsed_uri.hostname self.port = parsed_uri.port self.scheme = parsed_uri.scheme self.path = parsed_uri.path or "/ws" - - # Construct the full URL (in case path was missing) + if not parsed_uri.path: self.url = f"{self.scheme}://{parsed_uri.netloc}/ws" else: self.url = websocket_uri - - self.max_workers = max_workers + + self.max_workers = int(config.get("max_workers", 10)) + self.timeout = int(config.get("timeout", default_timeout)) self.ws: Optional[ClientWebSocketResponse] = None self.session: Optional[ClientSession] = None self.running = False self.reconnect_delay = 3.0 - - # Pulsar configuration - self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650") - self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None) - self.pulsar_listener = pulsar_listener - # Create backend using factory - backend_params = { - 'pulsar_host': self.pulsar_host, - 'pulsar_api_key': self.pulsar_api_key, - 'pulsar_listener': self.pulsar_listener, - } - self.backend = get_pubsub(**backend_params) + self.backend = get_pubsub(**config) - # Initialize config receiver - self.config_receiver = ConfigReceiver(self.backend) + self.auth = IamAuth( + backend=self.backend, + id=config.get("id", "rev-gateway"), + ) + + self.config_receiver = ConfigReceiver( + self.backend, auth=self.auth, + ) + + self.dispatcher = MessageDispatcher( + self.max_workers, self.config_receiver, self.backend, + auth=self.auth, timeout=self.timeout, + ) - # Initialize dispatcher with config_receiver and backend - must be created after config_receiver - self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend) - async def connect(self) -> bool: try: if self.session is None: self.session = ClientSession() - + logger.info(f"Connecting to {self.url}") self.ws = await self.session.ws_connect(self.url) - logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}") + logger.info( + f"WebSocket connection established to " + f"{self.host}:{self.port or 'default'}" + ) return True - + except Exception as e: logger.error(f"Failed to connect to {self.url}: {e}") return False - + async def disconnect(self): if self.ws and not self.ws.closed: await self.ws.close() @@ -92,32 +102,31 @@ class ReverseGateway: await self.session.close() self.ws = None self.session = None - + async def send_message(self, message: dict): if self.ws and not self.ws.closed: try: await self.ws.send_str(json.dumps(message)) except Exception as e: logger.error(f"Failed to send message: {e}") - + async def handle_message(self, message: str): try: logger.debug(f"Received message: {message}") - + msg_data = json.loads(message) - response = await self.dispatcher.handle_message(msg_data) - - if response: - await self.send_message(response) - + await self.dispatcher.handle_message( + msg_data, self.send_message, + ) + except Exception as e: logger.error(f"Error handling message: {e}") - + async def listen(self): while self.running and self.ws and not self.ws.closed: try: msg = await self.ws.receive() - + if msg.type == WSMsgType.TEXT: await self.handle_message(msg.data) elif msg.type == WSMsgType.BINARY: @@ -125,31 +134,33 @@ class ReverseGateway: elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR): logger.warning("WebSocket closed or error occurred") break - + except Exception as e: logger.error(f"Error in listen loop: {e}") break - + async def run(self): self.running = True logger.info("Starting reverse gateway") - - # Start config receiver - logger.info("Starting config receiver") + + await self.auth.start() await self.config_receiver.start() - + while self.running: try: if await self.connect(): await self.listen() else: - logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds") - + logger.warning( + f"Connection failed, retrying in " + f"{self.reconnect_delay} seconds" + ) + await self.disconnect() - + if self.running: await asyncio.sleep(self.reconnect_delay) - + except KeyboardInterrupt: logger.info("Shutdown requested") break @@ -157,77 +168,69 @@ class ReverseGateway: logger.error(f"Unexpected error: {e}") if self.running: await asyncio.sleep(self.reconnect_delay) - + await self.shutdown() - + async def shutdown(self): logger.info("Shutting down reverse gateway") self.running = False await self.dispatcher.shutdown() await self.disconnect() - # Close backend if hasattr(self, 'backend'): self.backend.close() - + def stop(self): self.running = False -def parse_args(): + +def run(): + parser = argparse.ArgumentParser( prog="reverse-gateway", - description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge" + description=__doc__, ) - + + parser.add_argument( + '--id', + default='rev-gateway', + help='Service identifier (default: rev-gateway)', + ) + parser.add_argument( '--websocket-uri', default=None, - help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)' + help=f'WebSocket URI to connect to (default: {default_websocket})', ) - + parser.add_argument( '--max-workers', type=int, default=10, - help='Maximum concurrent message handlers (default: 10)' + help='Maximum concurrent message handlers (default: 10)', ) - - parser.add_argument( - '-p', '--pulsar-host', - default=None, - help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)' - ) - - parser.add_argument( - '--pulsar-api-key', - default=None, - help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)' - ) - - parser.add_argument( - '--pulsar-listener', - default=None, - help='Pulsar listener name' - ) - - return parser.parse_args() -def run(): - args = parse_args() - - gateway = ReverseGateway( - websocket_uri=args.websocket_uri, - max_workers=args.max_workers, - pulsar_host=args.pulsar_host, - pulsar_api_key=args.pulsar_api_key, - pulsar_listener=args.pulsar_listener + parser.add_argument( + '--timeout', + type=int, + default=default_timeout, + help=f'Request timeout in seconds (default: {default_timeout})', ) - + + add_pubsub_args(parser) + add_logging_args(parser) + + args = parser.parse_args() + args = vars(args) + + setup_logging(args) + + gateway = ReverseGateway(**args) + logger.info(f"Starting reverse gateway:") logger.info(f" WebSocket URI: {gateway.url}") - logger.info(f" Max workers: {args.max_workers}") - logger.info(f" Pulsar host: {gateway.pulsar_host}") - + logger.info(f" Max workers: {gateway.max_workers}") + try: asyncio.run(gateway.run()) except KeyboardInterrupt: