Update rev-gateway for IAM integration (#940)

service.py:
- Constructor takes **config (same pattern as api-gateway) instead
  of individual args
- Creates IamAuth and calls await self.auth.start() before the
  message loop
- Passes auth to both ConfigReceiver and MessageDispatcher
- Uses add_pubsub_args / add_logging_args instead of hand-rolled
  Pulsar args
- Passes timeout through

dispatcher.py:
- Accepts auth and timeout parameters
- Passes both to DispatcherManager — fixes the missing auth argument
  that would have crashed on startup

The remote end's requests now go through the same IAM authentication
path as api-gateway. Token validation, workspace resolution, and
permissions all work identically regardless of which direction
initiated the connection.

Fixed tests — the test now passes auth and timeout to MessageDispatcher
and verifies they're forwarded to DispatcherManager.

Update rev gateway dispatcher to align with IAM.  A "token" parameter
must be passed with each message.

Fix websocket relay to align with rev-gateway changes, conforms to
the api-gateway protocol.
This commit is contained in:
cybermaggedon 2026-05-19 21:45:43 +01:00 committed by GitHub
parent 4e3bd85abc
commit e57f4669e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 914 additions and 865 deletions

View file

@ -3,208 +3,278 @@
WebSocket Relay Test Harness WebSocket Relay Test Harness
This script creates a relay server with two WebSocket endpoints: This script creates a relay server with two WebSocket endpoints:
- /in - for test clients to connect to - /in - for test clients to connect to (speaks api-gateway protocol)
- /out - for reverse gateway to connect to - /out - for reverse gateway to connect to (speaks rev-gateway protocol)
Messages are bidirectionally relayed between the two connections. Clients on /in authenticate with a first-frame auth message:
{"type": "auth", "token": "..."}
The relay stores the token and injects it into each subsequent message
before forwarding to /out. Responses from /out are forwarded back to
the originating /in connection unchanged.
Usage: Usage:
python websocket_relay.py [--port PORT] [--host HOST] python websocket_relay.py [--port PORT] [--host HOST]
""" """
import asyncio import asyncio
import json
import logging import logging
import argparse import argparse
from aiohttp import web, WSMsgType from aiohttp import web, WSMsgType
import weakref from typing import Dict, Optional
from typing import Optional, Set
# Configure logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
) )
logger = logging.getLogger("websocket_relay") logger = logging.getLogger("websocket_relay")
class InConnection:
def __init__(self, ws, conn_id):
self.ws = ws
self.conn_id = conn_id
self.token: Optional[str] = None
self.authenticated = False
class WebSocketRelay: class WebSocketRelay:
"""WebSocket relay that forwards messages between 'in' and 'out' connections"""
def __init__(self): def __init__(self):
self.in_connections: Set = weakref.WeakSet() self.in_connections: Dict[str, InConnection] = {}
self.out_connections: Set = weakref.WeakSet() self.out_connections: set = set()
self._conn_counter = 0
def _next_conn_id(self):
self._conn_counter += 1
return f"conn-{self._conn_counter}"
async def handle_in_connection(self, request): async def handle_in_connection(self, request):
"""Handle incoming connections on /in endpoint"""
ws = web.WebSocketResponse() ws = web.WebSocketResponse()
await ws.prepare(request) await ws.prepare(request)
self.in_connections.add(ws) conn_id = self._next_conn_id()
logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") conn = InConnection(ws, conn_id)
self.in_connections[conn_id] = conn
logger.info(
f"New 'in' connection {conn_id}. "
f"Total in: {len(self.in_connections)}, "
f"out: {len(self.out_connections)}"
)
try: try:
async for msg in ws: async for msg in ws:
if msg.type == WSMsgType.TEXT: if msg.type == WSMsgType.TEXT:
data = msg.data await self._handle_in_message(conn, msg.data)
logger.info(f"IN → OUT: {data}")
await self._forward_to_out(data)
elif msg.type == WSMsgType.BINARY:
data = msg.data
logger.info(f"IN → OUT: {len(data)} bytes (binary)")
await self._forward_to_out(data, binary=True)
elif msg.type == WSMsgType.ERROR: elif msg.type == WSMsgType.ERROR:
logger.error(f"WebSocket error on 'in' connection: {ws.exception()}") logger.error(
f"WebSocket error on 'in' connection "
f"{conn_id}: {ws.exception()}"
)
break break
else: else:
break break
except Exception as e: except Exception as e:
logger.error(f"Error in 'in' connection handler: {e}") logger.error(
f"Error in 'in' connection {conn_id}: {e}"
)
finally: finally:
logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") del self.in_connections[conn_id]
logger.info(
f"'in' connection {conn_id} closed. "
f"Remaining in: {len(self.in_connections)}, "
f"out: {len(self.out_connections)}"
)
return ws return ws
async def handle_out_connection(self, request): async def _handle_in_message(self, conn, data):
"""Handle outgoing connections on /out endpoint"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self.out_connections.add(ws)
logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
try: try:
async for msg in ws: message = json.loads(data)
if msg.type == WSMsgType.TEXT: except json.JSONDecodeError:
data = msg.data logger.warning(
logger.info(f"OUT → IN: {data}") f"{conn.conn_id}: received non-JSON message"
await self._forward_to_in(data) )
elif msg.type == WSMsgType.BINARY:
data = msg.data
logger.info(f"OUT → IN: {len(data)} bytes (binary)")
await self._forward_to_in(data, binary=True)
elif msg.type == WSMsgType.ERROR:
logger.error(f"WebSocket error on 'out' connection: {ws.exception()}")
break
else:
break
except Exception as e:
logger.error(f"Error in 'out' connection handler: {e}")
finally:
logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
return ws
async def _forward_to_out(self, data, binary=False):
"""Forward message from 'in' to all 'out' connections"""
if not self.out_connections:
logger.warning("No 'out' connections available to forward message")
return return
closed_connections = [] if isinstance(message, dict) and message.get("type") == "auth":
conn.token = message.get("token", "")
conn.authenticated = True
logger.info(f"{conn.conn_id}: authenticated")
await conn.ws.send_json({
"type": "auth-ok",
"workspace": "relayed",
})
return
if not conn.authenticated:
await conn.ws.send_json({
"error": {
"message": "auth required",
"type": "auth-required",
},
"complete": True,
})
return
message["token"] = conn.token
message["_relay_conn"] = conn.conn_id
forwarded = json.dumps(message)
logger.info(f"IN {conn.conn_id} → OUT: {forwarded}")
await self._forward_to_out(forwarded)
async def handle_out_connection(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
self.out_connections.add(ws)
logger.info(
f"New 'out' connection. "
f"Total in: {len(self.in_connections)}, "
f"out: {len(self.out_connections)}"
)
try:
async for msg in ws:
if msg.type == WSMsgType.TEXT:
await self._handle_out_message(msg.data)
elif msg.type == WSMsgType.ERROR:
logger.error(
f"WebSocket error on 'out' connection: "
f"{ws.exception()}"
)
break
else:
break
except Exception as e:
logger.error(f"Error in 'out' connection: {e}")
finally:
self.out_connections.discard(ws)
logger.info(
f"'out' connection closed. "
f"Remaining in: {len(self.in_connections)}, "
f"out: {len(self.out_connections)}"
)
return ws
async def _handle_out_message(self, data):
try:
message = json.loads(data)
except json.JSONDecodeError:
logger.warning("OUT: received non-JSON message")
return
conn_id = message.pop("_relay_conn", None)
forwarded = json.dumps(message)
logger.info(f"OUT → IN {conn_id or 'broadcast'}: {forwarded}")
if conn_id and conn_id in self.in_connections:
conn = self.in_connections[conn_id]
try:
if not conn.ws.closed:
await conn.ws.send_str(forwarded)
except Exception as e:
logger.error(
f"Error forwarding to 'in' {conn_id}: {e}"
)
else:
await self._broadcast_to_in(forwarded)
async def _broadcast_to_in(self, data):
closed = []
for conn_id, conn in list(self.in_connections.items()):
try:
if conn.ws.closed:
closed.append(conn_id)
continue
await conn.ws.send_str(data)
except Exception as e:
logger.error(
f"Error broadcasting to 'in' {conn_id}: {e}"
)
closed.append(conn_id)
for conn_id in closed:
self.in_connections.pop(conn_id, None)
async def _forward_to_out(self, data):
closed = []
for ws in list(self.out_connections): for ws in list(self.out_connections):
try: try:
if ws.closed: if ws.closed:
closed_connections.append(ws) closed.append(ws)
continue continue
await ws.send_str(data)
if binary:
await ws.send_bytes(data)
else:
await ws.send_str(data)
except Exception as e: except Exception as e:
logger.error(f"Error forwarding to 'out' connection: {e}") logger.error(f"Error forwarding to 'out': {e}")
closed_connections.append(ws) closed.append(ws)
for ws in closed:
# Clean up closed connections self.out_connections.discard(ws)
for ws in closed_connections:
if ws in self.out_connections:
self.out_connections.discard(ws)
async def _forward_to_in(self, data, binary=False):
"""Forward message from 'out' to all 'in' connections"""
if not self.in_connections:
logger.warning("No 'in' connections available to forward message")
return
closed_connections = []
for ws in list(self.in_connections):
try:
if ws.closed:
closed_connections.append(ws)
continue
if binary:
await ws.send_bytes(data)
else:
await ws.send_str(data)
except Exception as e:
logger.error(f"Error forwarding to 'in' connection: {e}")
closed_connections.append(ws)
# Clean up closed connections
for ws in closed_connections:
if ws in self.in_connections:
self.in_connections.discard(ws)
async def create_app(relay): async def create_app(relay):
"""Create the web application with routes"""
app = web.Application() app = web.Application()
# Add routes app.router.add_get('/in/api/v1/socket', relay.handle_in_connection)
app.router.add_get('/in', relay.handle_in_connection)
app.router.add_get('/out', relay.handle_out_connection) app.router.add_get('/out', relay.handle_out_connection)
# Add a simple status endpoint
async def status(request): async def status(request):
status_info = { return web.json_response({
'in_connections': len(relay.in_connections), 'in_connections': len(relay.in_connections),
'out_connections': len(relay.out_connections), 'out_connections': len(relay.out_connections),
'status': 'running' 'status': 'running',
} })
return web.json_response(status_info)
app.router.add_get('/status', status) app.router.add_get('/status', status)
app.router.add_get('/', status) # Root also shows status app.router.add_get('/', status)
return app return app
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="WebSocket Relay Test Harness" description="WebSocket Relay Test Harness"
) )
parser.add_argument( parser.add_argument(
'--host', '--host',
default='localhost', default='localhost',
help='Host to bind to (default: localhost)' help='Host to bind to (default: localhost)',
) )
parser.add_argument( parser.add_argument(
'--port', '--port',
type=int, type=int,
default=8080, default=8080,
help='Port to bind to (default: 8080)' help='Port to bind to (default: 8080)',
) )
parser.add_argument( parser.add_argument(
'--verbose', '-v', '--verbose', '-v',
action='store_true', action='store_true',
help='Enable verbose logging' help='Enable verbose logging',
) )
args = parser.parse_args() args = parser.parse_args()
if args.verbose: if args.verbose:
logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)
relay = WebSocketRelay() relay = WebSocketRelay()
print(f"Starting WebSocket Relay on {args.host}:{args.port}") print(f"Starting WebSocket Relay on {args.host}:{args.port}")
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in") print(f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket")
print(f" 'out' endpoint: ws://{args.host}:{args.port}/out") print(f" 'out' endpoint: ws://{args.host}:{args.port}/out")
print(f" Status: http://{args.host}:{args.port}/status") print(f" Status: http://{args.host}:{args.port}/status")
print() print()
print("Usage:") print("Client protocol (same as api-gateway):")
print(f" Test client connects to: ws://{args.host}:{args.port}/in") print(' 1. Connect to /in/api/v1/socket')
print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out") print(' 2. Send: {"type": "auth", "token": "tg_..."}')
print(' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}')
print(' 4. Send requests as normal')
web.run_app(create_app(relay), host=args.host, port=args.port) web.run_app(create_app(relay), host=args.host, port=args.port)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View file

@ -25,16 +25,17 @@ class TestSemaphoreEnforcement:
max_concurrent = 0 max_concurrent = 0
processing_event = asyncio.Event() processing_event = asyncio.Event()
async def slow_process(message): async def slow_process(message, sender):
nonlocal concurrent_count, max_concurrent nonlocal concurrent_count, max_concurrent
concurrent_count += 1 concurrent_count += 1
max_concurrent = max(max_concurrent, concurrent_count) max_concurrent = max(max_concurrent, concurrent_count)
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
concurrent_count -= 1 concurrent_count -= 1
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = slow_process dispatcher._process_message = slow_process
sender = AsyncMock()
# Launch more tasks than max_workers # Launch more tasks than max_workers
messages = [ messages = [
{"id": f"msg-{i}", "service": "test", "request": {}} {"id": f"msg-{i}", "service": "test", "request": {}}
@ -42,7 +43,7 @@ class TestSemaphoreEnforcement:
] ]
tasks = [ tasks = [
asyncio.create_task(dispatcher.handle_message(m)) asyncio.create_task(dispatcher.handle_message(m, sender))
for m in messages for m in messages
] ]
@ -66,17 +67,17 @@ class TestSemaphoreEnforcement:
original_process = dispatcher._process_message original_process = dispatcher._process_message
async def tracking_process(message): async def tracking_process(message, sender):
nonlocal task_was_tracked nonlocal task_was_tracked
# During processing, our task should be in active_tasks # During processing, our task should be in active_tasks
if len(dispatcher.active_tasks) > 0: if len(dispatcher.active_tasks) > 0:
task_was_tracked = True task_was_tracked = True
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = tracking_process dispatcher._process_message = tracking_process
await dispatcher.handle_message( await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}} {"id": "test", "service": "test", "request": {}},
AsyncMock(),
) )
assert task_was_tracked assert task_was_tracked
@ -88,7 +89,7 @@ class TestSemaphoreEnforcement:
"""Semaphore should be released even if processing raises.""" """Semaphore should be released even if processing raises."""
dispatcher = MessageDispatcher(max_workers=2) dispatcher = MessageDispatcher(max_workers=2)
async def failing_process(message): async def failing_process(message, sender):
raise RuntimeError("process failed") raise RuntimeError("process failed")
dispatcher._process_message = failing_process dispatcher._process_message = failing_process
@ -96,7 +97,8 @@ class TestSemaphoreEnforcement:
# Should not deadlock — semaphore must be released on error # Should not deadlock — semaphore must be released on error
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await dispatcher.handle_message( await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}} {"id": "test", "service": "test", "request": {}},
AsyncMock(),
) )
# Semaphore should be back at max # Semaphore should be back at max
@ -109,17 +111,18 @@ class TestSemaphoreEnforcement:
order = [] order = []
async def ordered_process(message): async def ordered_process(message, sender):
msg_id = message["id"] msg_id = message["id"]
order.append(f"start-{msg_id}") order.append(f"start-{msg_id}")
await asyncio.sleep(0.02) await asyncio.sleep(0.02)
order.append(f"end-{msg_id}") order.append(f"end-{msg_id}")
return {"id": msg_id, "response": {"ok": True}}
dispatcher._process_message = ordered_process dispatcher._process_message = ordered_process
sender = AsyncMock()
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)] messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages] tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# With semaphore=1, each message should complete before next starts # With semaphore=1, each message should complete before next starts

View file

View file

@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher
""" """
import pytest import pytest
from unittest.mock import MagicMock, AsyncMock, patch import asyncio
from unittest.mock import MagicMock, AsyncMock, patch, ANY
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher from trustgraph.rev_gateway.dispatcher import MessageDispatcher
class TestWebSocketResponder:
"""Test cases for WebSocketResponder class"""
def test_websocket_responder_initialization(self):
"""Test WebSocketResponder initialization"""
responder = WebSocketResponder()
assert responder.response is None
assert responder.completed is False
@pytest.mark.asyncio
async def test_websocket_responder_send_method(self):
"""Test WebSocketResponder send method"""
responder = WebSocketResponder()
test_response = {"data": "test response"}
# Call send method
await responder.send(test_response)
# Verify response was stored
assert responder.response == test_response
@pytest.mark.asyncio
async def test_websocket_responder_call_method(self):
"""Test WebSocketResponder __call__ method"""
responder = WebSocketResponder()
test_response = {"result": "success"}
test_completed = True
# Call the responder
await responder(test_response, test_completed)
# Verify response and completed status were set
assert responder.response == test_response
assert responder.completed == test_completed
@pytest.mark.asyncio
async def test_websocket_responder_call_method_with_false_completion(self):
"""Test WebSocketResponder __call__ method with incomplete response"""
responder = WebSocketResponder()
test_response = {"partial": "data"}
test_completed = False
# Call the responder
await responder(test_response, test_completed)
# Verify response was set and completed is True (since send() always sets completed=True)
assert responder.response == test_response
assert responder.completed is True
class TestMessageDispatcher: class TestMessageDispatcher:
"""Test cases for MessageDispatcher class""" """Test cases for MessageDispatcher class"""
def test_message_dispatcher_initialization_with_defaults(self): def test_message_dispatcher_initialization_with_defaults(self):
"""Test MessageDispatcher initialization with default parameters"""
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
assert dispatcher.max_workers == 10 assert dispatcher.max_workers == 10
assert dispatcher.semaphore._value == 10 assert dispatcher.semaphore._value == 10
assert dispatcher.active_tasks == set() assert dispatcher.active_tasks == set()
assert dispatcher.backend is None assert dispatcher.backend is None
assert dispatcher.auth is None
assert dispatcher.dispatcher_manager is None assert dispatcher.dispatcher_manager is None
assert len(dispatcher.service_mapping) > 0 assert len(dispatcher.service_mapping) > 0
def test_message_dispatcher_initialization_with_custom_workers(self): def test_message_dispatcher_initialization_with_custom_workers(self):
"""Test MessageDispatcher initialization with custom max_workers"""
dispatcher = MessageDispatcher(max_workers=5) dispatcher = MessageDispatcher(max_workers=5)
assert dispatcher.max_workers == 5 assert dispatcher.max_workers == 5
assert dispatcher.semaphore._value == 5 assert dispatcher.semaphore._value == 5
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager') @patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager): def test_message_dispatcher_initialization_with_backend(
"""Test MessageDispatcher initialization with pulsar_client and config_receiver""" self, mock_dispatcher_manager,
):
mock_backend = MagicMock() mock_backend = MagicMock()
mock_config_receiver = MagicMock() mock_config_receiver = MagicMock()
mock_auth = MagicMock()
mock_dispatcher_instance = MagicMock() mock_dispatcher_instance = MagicMock()
mock_dispatcher_manager.return_value = mock_dispatcher_instance mock_dispatcher_manager.return_value = mock_dispatcher_instance
dispatcher = MessageDispatcher( dispatcher = MessageDispatcher(
max_workers=8, max_workers=8,
config_receiver=mock_config_receiver, config_receiver=mock_config_receiver,
backend=mock_backend backend=mock_backend,
auth=mock_auth,
timeout=300,
) )
assert dispatcher.max_workers == 8 assert dispatcher.max_workers == 8
assert dispatcher.backend == mock_backend assert dispatcher.backend == mock_backend
assert dispatcher.auth == mock_auth
assert dispatcher.dispatcher_manager == mock_dispatcher_instance assert dispatcher.dispatcher_manager == mock_dispatcher_instance
mock_dispatcher_manager.assert_called_once_with( mock_dispatcher_manager.assert_called_once_with(
mock_backend, mock_config_receiver, prefix="rev-gateway" mock_backend, mock_config_receiver,
auth=mock_auth, prefix="rev-gateway", timeout=300,
) )
def test_message_dispatcher_service_mapping(self): def test_message_dispatcher_service_mapping(self):
"""Test MessageDispatcher service mapping contains expected services"""
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
expected_services = [ expected_services = [
"text-completion", "graph-rag", "agent", "embeddings", "text-completion", "graph-rag", "agent", "embeddings",
"graph-embeddings", "triples", "document-load", "text-load", "graph-embeddings", "triples", "document-load", "text-load",
"flow", "knowledge", "config", "librarian", "document-rag" "flow", "knowledge", "config", "librarian", "document-rag",
] ]
for service in expected_services: for service in expected_services:
assert service in dispatcher.service_mapping assert service in dispatcher.service_mapping
# Test specific mappings
assert dispatcher.service_mapping["text-completion"] == "text-completion"
assert dispatcher.service_mapping["document-load"] == "document" assert dispatcher.service_mapping["document-load"] == "document"
assert dispatcher.service_mapping["text-load"] == "text-document" assert dispatcher.service_mapping["text-load"] == "text-document"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self): async def test_handle_message_without_dispatcher_manager(self):
"""Test MessageDispatcher handle_message without dispatcher manager"""
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.auth = MagicMock()
test_message = { dispatcher.auth.authenticate = AsyncMock(
"id": "test-123", return_value=MagicMock(workspace="default")
"service": "test-service", )
"request": {"data": "test"}
} sender = AsyncMock()
result = await dispatcher.handle_message(test_message) await dispatcher.handle_message(
{"id": "test-1", "service": "test", "request": {}},
assert result["id"] == "test-123" sender,
assert "error" in result["response"] )
assert "DispatcherManager not available" in result["response"]["error"]
sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-1"
assert sent["error"]["message"] == "DispatcherManager not available"
assert sent["error"]["type"] == "error"
assert sent["complete"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_handle_message_with_exception(self): async def test_handle_message_auth_failure(self):
"""Test MessageDispatcher handle_message with exception during processing"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
test_message = { side_effect=Exception("auth failure")
"id": "test-456", )
"service": "text-completion", dispatcher.dispatcher_manager = MagicMock()
"request": {"prompt": "test"}
} sender = AsyncMock()
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): await dispatcher.handle_message(
result = await dispatcher.handle_message(test_message) {"id": "test-2", "token": "bad", "service": "test", "request": {}},
sender,
assert result["id"] == "test-456" )
assert "error" in result["response"]
assert "Test error" in result["response"]["error"] sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-2"
assert "auth failure" in sent["error"]["message"]
assert sent["complete"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_handle_message_global_service(self): async def test_handle_message_global_service(self):
"""Test MessageDispatcher handle_message with global service""" mock_dm = MagicMock()
mock_dispatcher_manager = MagicMock() mock_dm.invoke_global_service = AsyncMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"result": "success"}
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
test_message = { dispatcher.auth.authenticate = AsyncMock(
"id": "test-789", return_value=MagicMock(workspace="ws1")
"service": "text-completion", )
"request": {"prompt": "hello"}
} sender = AsyncMock()
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): with patch(
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): 'trustgraph.gateway.dispatch.manager.global_dispatchers',
result = await dispatcher.handle_message(test_message) {"text-completion": True},
):
assert result["id"] == "test-789" await dispatcher.handle_message(
assert result["response"] == {"result": "success"} {
mock_dispatcher_manager.invoke_global_service.assert_called_once() "id": "test-3",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hello"},
},
sender,
)
mock_dm.invoke_global_service.assert_called_once()
args, kwargs = mock_dm.invoke_global_service.call_args
assert args[0] == {"prompt": "hello"}
assert args[2] == "text-completion"
assert kwargs["workspace"] == "ws1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_handle_message_flow_service(self): async def test_handle_message_flow_service(self):
"""Test MessageDispatcher handle_message with flow service""" mock_dm = MagicMock()
mock_dispatcher_manager = MagicMock() mock_dm.invoke_flow_service = AsyncMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"data": "flow_result"}
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
test_message = { dispatcher.auth.authenticate = AsyncMock(
"id": "test-flow-123", return_value=MagicMock(workspace="ws2")
"service": "document-rag", )
"request": {"query": "test"},
"flow": "custom-flow" sender = AsyncMock()
}
with patch(
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}): 'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): ):
result = await dispatcher.handle_message(test_message) await dispatcher.handle_message(
{
assert result["id"] == "test-flow-123" "id": "test-4",
assert result["response"] == {"data": "flow_result"} "token": "tg_key",
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with( "service": "document-rag",
{"query": "test"}, mock_responder, "custom-flow", "document-rag" "request": {"query": "test"},
"flow": "my-flow",
},
sender,
)
mock_dm.invoke_flow_service.assert_called_once_with(
{"query": "test"}, ANY, "ws2", "my-flow", "document-rag",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_handle_message_incomplete_response(self): async def test_handle_message_responder_sends_frames(self):
"""Test MessageDispatcher handle_message with incomplete response""" mock_dm = MagicMock()
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock() async def fake_invoke(data, responder, svc, workspace=None):
mock_responder = MagicMock() await responder({"partial": 1}, False)
mock_responder.completed = False await responder({"partial": 2}, True)
mock_responder.response = None
mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke)
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
test_message = { dispatcher.auth.authenticate = AsyncMock(
"id": "test-incomplete", return_value=MagicMock(workspace="ws1")
"service": "agent", )
"request": {"input": "test"}
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers',
{"text-completion": True},
):
await dispatcher.handle_message(
{
"id": "test-5",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hi"},
},
sender,
)
assert sender.call_count == 2
first = sender.call_args_list[0][0][0]
second = sender.call_args_list[1][0][0]
assert first == {
"id": "test-5", "response": {"partial": 1}, "complete": False,
}
assert second == {
"id": "test-5", "response": {"partial": 2}, "complete": True,
} }
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-incomplete"
assert result["response"] == {"error": "No response received"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_shutdown(self): async def test_handle_message_workspace_from_identity(self):
"""Test MessageDispatcher shutdown method""" mock_dm = MagicMock()
import asyncio mock_dm.invoke_flow_service = AsyncMock()
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dm
# Create actual async tasks dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="derived-ws")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
):
await dispatcher.handle_message(
{
"id": "test-6",
"token": "tg_key",
"service": "agent",
"request": {"question": "test"},
"flow": "default",
},
sender,
)
args = mock_dm.invoke_flow_service.call_args[0]
assert args[2] == "derived-ws"
@pytest.mark.asyncio
async def test_shutdown(self):
dispatcher = MessageDispatcher()
async def dummy_task(): async def dummy_task():
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
return "done"
task1 = asyncio.create_task(dummy_task()) task1 = asyncio.create_task(dummy_task())
task2 = asyncio.create_task(dummy_task()) task2 = asyncio.create_task(dummy_task())
dispatcher.active_tasks = {task1, task2} dispatcher.active_tasks = {task1, task2}
# Call shutdown
await dispatcher.shutdown() await dispatcher.shutdown()
# Verify tasks were completed
assert task1.done() assert task1.done()
assert task2.done() assert task2.done()
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_shutdown_with_no_tasks(self): async def test_shutdown_with_no_tasks(self):
"""Test MessageDispatcher shutdown with no active tasks"""
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
# Call shutdown with no active tasks
await dispatcher.shutdown() await dispatcher.shutdown()
# Should complete without error assert dispatcher.active_tasks == set()
assert dispatcher.active_tasks == set()

View file

@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock
from aiohttp import WSMsgType, ClientWebSocketResponse from aiohttp import WSMsgType, ClientWebSocketResponse
import json import json
from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run from trustgraph.rev_gateway.service import ReverseGateway, run
MOCK_PATCHES = [
'trustgraph.rev_gateway.service.IamAuth',
'trustgraph.rev_gateway.service.ConfigReceiver',
'trustgraph.rev_gateway.service.MessageDispatcher',
'trustgraph.rev_gateway.service.get_pubsub',
]
def make_gateway(**overrides):
config = {"websocket_uri": "ws://localhost:7650/out"}
config.update(overrides)
return ReverseGateway(**config)
class TestReverseGateway: class TestReverseGateway:
"""Test cases for ReverseGateway class""" """Test cases for ReverseGateway class"""
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway initialization with default parameters""" def test_reverse_gateway_initialization_defaults(
mock_backend = MagicMock() self, mock_get_pubsub, mock_dispatcher,
mock_get_pubsub.return_value = mock_backend mock_config_receiver, mock_iam_auth,
):
gateway = ReverseGateway() mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
assert gateway.websocket_uri == "ws://localhost:7650/out" assert gateway.websocket_uri == "ws://localhost:7650/out"
assert gateway.host == "localhost" assert gateway.host == "localhost"
assert gateway.port == 7650 assert gateway.port == 7650
@ -33,25 +49,22 @@ class TestReverseGateway:
assert gateway.max_workers == 10 assert gateway.max_workers == 10
assert gateway.running is False assert gateway.running is False
assert gateway.reconnect_delay == 3.0 assert gateway.reconnect_delay == 3.0
assert gateway.pulsar_host == "pulsar://pulsar:6650"
assert gateway.pulsar_api_key is None
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway initialization with custom parameters""" def test_reverse_gateway_initialization_custom_params(
mock_backend = MagicMock() self, mock_get_pubsub, mock_dispatcher,
mock_get_pubsub.return_value = mock_backend mock_config_receiver, mock_iam_auth,
):
gateway = ReverseGateway( mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway(
websocket_uri="wss://example.com:8080/websocket", websocket_uri="wss://example.com:8080/websocket",
max_workers=20, max_workers=20,
pulsar_host="pulsar://custom:6650",
pulsar_api_key="test-key",
pulsar_listener="test-listener"
) )
assert gateway.websocket_uri == "wss://example.com:8080/websocket" assert gateway.websocket_uri == "wss://example.com:8080/websocket"
assert gateway.host == "example.com" assert gateway.host == "example.com"
assert gateway.port == 8080 assert gateway.port == 8080
@ -59,340 +72,360 @@ class TestReverseGateway:
assert gateway.path == "/websocket" assert gateway.path == "/websocket"
assert gateway.url == "wss://example.com:8080/websocket" assert gateway.url == "wss://example.com:8080/websocket"
assert gateway.max_workers == 20 assert gateway.max_workers == 20
assert gateway.pulsar_host == "pulsar://custom:6650"
assert gateway.pulsar_api_key == "test-key"
assert gateway.pulsar_listener == "test-listener"
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway initialization with WebSocket URI missing path""" def test_reverse_gateway_initialization_with_missing_path(
mock_backend = MagicMock() self, mock_get_pubsub, mock_dispatcher,
mock_get_pubsub.return_value = mock_backend mock_config_receiver, mock_iam_auth,
):
gateway = ReverseGateway(websocket_uri="ws://example.com") mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway(websocket_uri="ws://example.com")
assert gateway.path == "/ws" assert gateway.path == "/ws"
assert gateway.url == "ws://example.com/ws" assert gateway.url == "ws://example.com/ws"
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway initialization with invalid WebSocket scheme""" def test_reverse_gateway_initialization_invalid_scheme(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"): with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
ReverseGateway(websocket_uri="http://example.com") make_gateway(websocket_uri="http://example.com")
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway initialization with missing hostname""" def test_reverse_gateway_initialization_missing_hostname(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
with pytest.raises(ValueError, match="WebSocket URI must include hostname"): with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
ReverseGateway(websocket_uri="ws://") make_gateway(websocket_uri="ws://")
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway creates backend with authentication""" def test_reverse_gateway_iam_auth_created(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock() mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway( gateway = make_gateway(id="test-rev-gw")
pulsar_api_key="test-key",
pulsar_listener="test-listener" mock_iam_auth.assert_called_once_with(
backend=mock_backend,
id="test-rev-gw",
) )
# Verify get_pubsub was called with the correct parameters @patch(*MOCK_PATCHES[0:1])
mock_get_pubsub.assert_called_once_with( @patch(*MOCK_PATCHES[1:2])
pulsar_host="pulsar://pulsar:6650", @patch(*MOCK_PATCHES[2:3])
pulsar_api_key="test-key", @patch(*MOCK_PATCHES[3:4])
pulsar_listener="test-listener" def test_reverse_gateway_config_receiver_gets_auth(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_auth_instance = MagicMock()
mock_iam_auth.return_value = mock_auth_instance
gateway = make_gateway()
mock_config_receiver.assert_called_once_with(
mock_backend, auth=mock_auth_instance,
) )
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@patch('trustgraph.rev_gateway.service.ClientSession') @patch('trustgraph.rev_gateway.service.ClientSession')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_connect_success(
"""Test ReverseGateway successful connection""" self, mock_session_class, mock_get_pubsub,
mock_backend = MagicMock() mock_dispatcher, mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
mock_session = AsyncMock() mock_session = AsyncMock()
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_session.ws_connect.return_value = mock_ws mock_session.ws_connect.return_value = mock_ws
mock_session_class.return_value = mock_session mock_session_class.return_value = mock_session
gateway = ReverseGateway() gateway = make_gateway()
result = await gateway.connect() result = await gateway.connect()
assert result is True assert result is True
assert gateway.session == mock_session assert gateway.session == mock_session
assert gateway.ws == mock_ws assert gateway.ws == mock_ws
mock_session.ws_connect.assert_called_once_with(gateway.url) mock_session.ws_connect.assert_called_once_with(gateway.url)
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@patch('trustgraph.rev_gateway.service.ClientSession') @patch('trustgraph.rev_gateway.service.ClientSession')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_connect_failure(
"""Test ReverseGateway connection failure""" self, mock_session_class, mock_get_pubsub,
mock_backend = MagicMock() mock_dispatcher, mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
mock_session = AsyncMock() mock_session = AsyncMock()
mock_session.ws_connect.side_effect = Exception("Connection failed") mock_session.ws_connect.side_effect = Exception("Connection failed")
mock_session_class.return_value = mock_session mock_session_class.return_value = mock_session
gateway = ReverseGateway() gateway = make_gateway()
result = await gateway.connect() result = await gateway.connect()
assert result is False assert result is False
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_disconnect(
"""Test ReverseGateway disconnect""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
# Mock websocket and session
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = False mock_ws.closed = False
mock_session = AsyncMock() mock_session = AsyncMock()
mock_session.closed = False mock_session.closed = False
gateway.ws = mock_ws gateway.ws = mock_ws
gateway.session = mock_session gateway.session = mock_session
await gateway.disconnect() await gateway.disconnect()
mock_ws.close.assert_called_once() mock_ws.close.assert_called_once()
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
assert gateway.ws is None assert gateway.ws is None
assert gateway.session is None assert gateway.session is None
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_send_message(
"""Test ReverseGateway send message""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
# Mock websocket
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = False mock_ws.closed = False
gateway.ws = mock_ws gateway.ws = mock_ws
test_message = {"id": "test", "data": "hello"} test_message = {"id": "test", "data": "hello"}
await gateway.send_message(test_message) await gateway.send_message(test_message)
mock_ws.send_str.assert_called_once_with(json.dumps(test_message)) mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_send_message_closed_connection(
"""Test ReverseGateway send message with closed connection""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
# Mock closed websocket
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = True mock_ws.closed = True
gateway.ws = mock_ws gateway.ws = mock_ws
test_message = {"id": "test", "data": "hello"} test_message = {"id": "test", "data": "hello"}
await gateway.send_message(test_message) await gateway.send_message(test_message)
# Should not call send_str on closed connection
mock_ws.send_str.assert_not_called() mock_ws.send_str.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_handle_message(
"""Test ReverseGateway handle message""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
mock_dispatcher_instance = AsyncMock()
mock_dispatcher_instance.handle_message.return_value = {"response": "success"} mock_dispatcher_instance = AsyncMock()
mock_dispatcher.return_value = mock_dispatcher_instance mock_dispatcher.return_value = mock_dispatcher_instance
gateway = ReverseGateway() gateway = make_gateway()
# Mock send_message
gateway.send_message = AsyncMock()
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
await gateway.handle_message(test_message)
mock_dispatcher_instance.handle_message.assert_called_once_with({
"id": "test",
"service": "test-service",
"request": {"data": "test"}
})
gateway.send_message.assert_called_once_with({"response": "success"})
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway handle message with invalid JSON"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock send_message
gateway.send_message = AsyncMock() gateway.send_message = AsyncMock()
test_message = 'invalid json' test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
# Should not raise exception
await gateway.handle_message(test_message) await gateway.handle_message(test_message)
# Should not call send_message due to error mock_dispatcher_instance.handle_message.assert_called_once_with(
{
"id": "test",
"service": "test-service",
"request": {"data": "test"},
},
gateway.send_message,
)
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message_invalid_json(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.send_message = AsyncMock()
await gateway.handle_message('invalid json')
gateway.send_message.assert_not_called() gateway.send_message.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_listen_text_message(
"""Test ReverseGateway listen with text message""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
gateway.running = True gateway.running = True
# Mock websocket
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = False mock_ws.closed = False
gateway.ws = mock_ws gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock() gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.type = WSMsgType.TEXT mock_msg.type = WSMsgType.TEXT
mock_msg.data = '{"test": "message"}' mock_msg.data = '{"test": "message"}'
# Mock receive to return message once, then raise exception to stop loop
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
# listen() catches exceptions and breaks, so no exception should be raised
await gateway.listen() await gateway.listen()
gateway.handle_message.assert_called_once_with('{"test": "message"}') gateway.handle_message.assert_called_once_with('{"test": "message"}')
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_listen_binary_message(
"""Test ReverseGateway listen with binary message""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
gateway.running = True gateway.running = True
# Mock websocket
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = False mock_ws.closed = False
gateway.ws = mock_ws gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock() gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.type = WSMsgType.BINARY mock_msg.type = WSMsgType.BINARY
mock_msg.data = b'{"test": "binary"}' mock_msg.data = b'{"test": "binary"}'
# Mock receive to return message once, then raise exception to stop loop
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
# listen() catches exceptions and breaks, so no exception should be raised
await gateway.listen() await gateway.listen()
gateway.handle_message.assert_called_once_with('{"test": "binary"}') gateway.handle_message.assert_called_once_with('{"test": "binary"}')
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_listen_close_message(
"""Test ReverseGateway listen with close message""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
gateway.running = True gateway.running = True
# Mock websocket
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = False mock_ws.closed = False
gateway.ws = mock_ws gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock() gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.type = WSMsgType.CLOSE mock_msg.type = WSMsgType.CLOSE
# Mock receive to return close message
mock_ws.receive.return_value = mock_msg mock_ws.receive.return_value = mock_msg
await gateway.listen() await gateway.listen()
# Should not call handle_message for close message
gateway.handle_message.assert_not_called() gateway.handle_message.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_shutdown(
"""Test ReverseGateway shutdown""" self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock() mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend mock_get_pubsub.return_value = mock_backend
mock_dispatcher_instance = AsyncMock() mock_dispatcher_instance = AsyncMock()
mock_dispatcher.return_value = mock_dispatcher_instance mock_dispatcher.return_value = mock_dispatcher_instance
gateway = ReverseGateway() gateway = make_gateway()
gateway.running = True gateway.running = True
# Mock disconnect
gateway.disconnect = AsyncMock() gateway.disconnect = AsyncMock()
await gateway.shutdown() await gateway.shutdown()
@ -402,46 +435,50 @@ class TestReverseGateway:
gateway.disconnect.assert_called_once() gateway.disconnect.assert_called_once()
mock_backend.close.assert_called_once() mock_backend.close.assert_called_once()
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway stop""" def test_reverse_gateway_stop(
mock_backend = MagicMock() self, mock_get_pubsub, mock_dispatcher,
mock_get_pubsub.return_value = mock_backend mock_config_receiver, mock_iam_auth,
):
gateway = ReverseGateway() mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True gateway.running = True
gateway.stop() gateway.stop()
assert gateway.running is False assert gateway.running is False
class TestReverseGatewayRun: class TestReverseGatewayRun:
"""Test cases for ReverseGateway run method""" """Test cases for ReverseGateway run method"""
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_run_successful_cycle(
"""Test ReverseGateway run method with successful connect/listen cycle""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
mock_auth_instance = AsyncMock()
mock_iam_auth.return_value = mock_auth_instance
mock_config_receiver_instance = AsyncMock() mock_config_receiver_instance = AsyncMock()
mock_config_receiver.return_value = mock_config_receiver_instance mock_config_receiver.return_value = mock_config_receiver_instance
gateway = ReverseGateway() gateway = make_gateway()
# Mock methods
gateway.connect = AsyncMock(return_value=True)
gateway.listen = AsyncMock() gateway.listen = AsyncMock()
gateway.disconnect = AsyncMock() gateway.disconnect = AsyncMock()
gateway.shutdown = AsyncMock() gateway.shutdown = AsyncMock()
# Stop after one iteration
call_count = 0 call_count = 0
async def mock_connect(): async def mock_connect():
nonlocal call_count nonlocal call_count
@ -451,91 +488,13 @@ class TestReverseGatewayRun:
else: else:
gateway.running = False gateway.running = False
return False return False
gateway.connect = mock_connect gateway.connect = mock_connect
await gateway.run() await gateway.run()
mock_auth_instance.start.assert_called_once()
mock_config_receiver_instance.start.assert_called_once() mock_config_receiver_instance.start.assert_called_once()
gateway.listen.assert_called_once() gateway.listen.assert_called_once()
# disconnect is called twice: once in the main loop, once in shutdown
assert gateway.disconnect.call_count == 2 assert gateway.disconnect.call_count == 2
gateway.shutdown.assert_called_once() gateway.shutdown.assert_called_once()
class TestReverseGatewayArgs:
"""Test cases for argument parsing and run function"""
def test_parse_args_defaults(self):
"""Test parse_args with default values"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = ['reverse-gateway']
try:
args = parse_args()
assert args.websocket_uri is None
assert args.max_workers == 10
assert args.pulsar_host is None
assert args.pulsar_api_key is None
assert args.pulsar_listener is None
finally:
sys.argv = original_argv
def test_parse_args_custom_values(self):
"""Test parse_args with custom values"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = [
'reverse-gateway',
'--websocket-uri', 'ws://custom:8080/ws',
'--max-workers', '20',
'--pulsar-host', 'pulsar://custom:6650',
'--pulsar-api-key', 'test-key',
'--pulsar-listener', 'test-listener'
]
try:
args = parse_args()
assert args.websocket_uri == 'ws://custom:8080/ws'
assert args.max_workers == 20
assert args.pulsar_host == 'pulsar://custom:6650'
assert args.pulsar_api_key == 'test-key'
assert args.pulsar_listener == 'test-listener'
finally:
sys.argv = original_argv
@patch('trustgraph.rev_gateway.service.ReverseGateway')
@patch('asyncio.run')
def test_run_function(self, mock_asyncio_run, mock_gateway_class):
"""Test run function"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = ['reverse-gateway', '--max-workers', '15']
try:
mock_gateway_instance = MagicMock()
mock_gateway_instance.url = "ws://localhost:7650/out"
mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650"
mock_gateway_class.return_value = mock_gateway_instance
run()
mock_gateway_class.assert_called_once_with(
websocket_uri=None,
max_workers=15,
pulsar_host=None,
pulsar_api_key=None,
pulsar_listener=None
)
mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run())
finally:
sys.argv = original_argv

View file

@ -1,130 +1,140 @@
import asyncio import asyncio
import logging import logging
import uuid import uuid
from typing import Dict, Any, Optional from typing import Dict, Any, Optional, Callable, Awaitable
from trustgraph.messaging import TranslatorRegistry
from ..gateway.dispatch.manager import DispatcherManager from ..gateway.dispatch.manager import DispatcherManager
logger = logging.getLogger("dispatcher") logger = logging.getLogger("dispatcher")
logger.setLevel(logging.INFO)
class WebSocketResponder:
"""Simple responder that captures response for websocket return""" class _TokenShim:
def __init__(self): def __init__(self, token):
self.response = None self.headers = (
self.completed = False {"Authorization": f"Bearer {token}"} if token else {}
)
async def send(self, data):
"""Capture the response data"""
self.response = data
self.completed = True
async def __call__(self, data, final=False):
"""Make the responder callable for compatibility with requestor"""
await self.send(data)
if final:
self.completed = True
class MessageDispatcher: class MessageDispatcher:
def __init__(self, max_workers: int = 10, config_receiver=None, backend=None): def __init__(self, max_workers=10, config_receiver=None, backend=None,
auth=None, timeout=120):
self.max_workers = max_workers self.max_workers = max_workers
self.semaphore = asyncio.Semaphore(max_workers) self.semaphore = asyncio.Semaphore(max_workers)
self.active_tasks = set() self.active_tasks = set()
self.backend = backend self.backend = backend
self.auth = auth
# Use DispatcherManager for flow and service management if backend and config_receiver and auth:
if backend and config_receiver: self.dispatcher_manager = DispatcherManager(
self.dispatcher_manager = DispatcherManager(backend, config_receiver, prefix="rev-gateway") backend, config_receiver,
auth=auth,
prefix="rev-gateway",
timeout=timeout,
)
else: else:
self.dispatcher_manager = None self.dispatcher_manager = None
logger.warning("No backend or config_receiver provided - using fallback mode") logger.warning(
"Missing backend, config_receiver, or auth "
# Service name mapping from websocket protocol to translator registry "— using fallback mode"
)
self.service_mapping = { self.service_mapping = {
"text-completion": "text-completion", "text-completion": "text-completion",
"graph-rag": "graph-rag", "graph-rag": "graph-rag",
"agent": "agent", "agent": "agent",
"embeddings": "embeddings", "embeddings": "embeddings",
"graph-embeddings": "graph-embeddings", "graph-embeddings": "graph-embeddings",
"triples": "triples", "triples": "triples",
"document-load": "document", "document-load": "document",
"text-load": "text-document", "text-load": "text-document",
"flow": "flow", "flow": "flow",
"knowledge": "knowledge", "knowledge": "knowledge",
"config": "config", "config": "config",
"librarian": "librarian", "librarian": "librarian",
"document-rag": "document-rag" "document-rag": "document-rag",
} }
async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]: async def handle_message(
self, message: Dict[Any, Any],
sender: Callable[[dict], Awaitable[None]],
):
async with self.semaphore: async with self.semaphore:
task = asyncio.create_task(self._process_message(message)) task = asyncio.create_task(
self._process_message(message, sender)
)
self.active_tasks.add(task) self.active_tasks.add(task)
try: try:
result = await task await task
return result
finally: finally:
self.active_tasks.discard(task) self.active_tasks.discard(task)
async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]: async def _authenticate(self, token):
if not self.auth:
raise RuntimeError("Auth not configured")
return await self.auth.authenticate(_TokenShim(token))
async def _process_message(
self, message: Dict[Any, Any],
sender: Callable[[dict], Awaitable[None]],
):
request_id = message.get('id', str(uuid.uuid4())) request_id = message.get('id', str(uuid.uuid4()))
service = message.get('service') service = message.get('service')
request_data = message.get('request', {}) request_data = message.get('request', {})
flow_id = message.get('flow', 'default') # Default flow token = message.get('token', '')
flow_id = message.get('flow', 'default')
logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
logger.info(
f"Processing message {request_id} for service "
f"{service} on flow {flow_id}"
)
try: try:
if not self.dispatcher_manager: if not self.dispatcher_manager:
raise RuntimeError("DispatcherManager not available - backend and config_receiver required") raise RuntimeError(
"DispatcherManager not available"
# Use DispatcherManager for flow-based processing )
responder = WebSocketResponder()
identity = await self._authenticate(token)
# Map websocket service name to dispatcher service name workspace = identity.workspace
async def responder(resp, fin):
await sender({
"id": request_id,
"response": resp,
"complete": fin,
})
dispatcher_service = self.service_mapping.get(service, service) dispatcher_service = self.service_mapping.get(service, service)
# Check if this is a global service or flow service
from ..gateway.dispatch.manager import global_dispatchers from ..gateway.dispatch.manager import global_dispatchers
if dispatcher_service in global_dispatchers: if dispatcher_service in global_dispatchers:
# Use global service dispatcher
await self.dispatcher_manager.invoke_global_service( await self.dispatcher_manager.invoke_global_service(
request_data, responder, dispatcher_service request_data, responder, dispatcher_service,
workspace=workspace,
) )
else: else:
# Use DispatcherManager to process the request through Pulsar queues
await self.dispatcher_manager.invoke_flow_service( await self.dispatcher_manager.invoke_flow_service(
request_data, responder, flow_id, dispatcher_service request_data, responder, workspace, flow_id,
dispatcher_service,
) )
# Get the response from the responder
if responder.completed:
response_data = responder.response
else:
response_data = {'error': 'No response received'}
response = {
'id': request_id,
'response': response_data
}
except Exception as e: except Exception as e:
logger.error(f"Error processing message {request_id}: {e}") logger.error(f"Error processing message {request_id}: {e}")
response = { await sender({
'id': request_id, "id": request_id,
'response': {'error': str(e)} "error": {"message": str(e), "type": "error"},
} "complete": True,
})
logger.info(f"Completed processing message {request_id}") logger.info(f"Completed processing message {request_id}")
return response
async def shutdown(self): async def shutdown(self):
if self.active_tasks: if self.active_tasks:
logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete") logger.info(
f"Waiting for {len(self.active_tasks)} active "
f"tasks to complete"
)
await asyncio.gather(*self.active_tasks, return_exceptions=True) await asyncio.gather(*self.active_tasks, return_exceptions=True)
# DispatcherManager handles its own cleanup
logger.info("Dispatcher shutdown complete") logger.info("Dispatcher shutdown complete")

View file

@ -1,3 +1,9 @@
"""
Reverse gateway. Initiates outbound WebSocket connections to a remote
relay and dispatches incoming requests through the same DispatcherManager
pipeline as api-gateway.
"""
import asyncio import asyncio
import argparse import argparse
import logging import logging
@ -9,82 +15,86 @@ from typing import Optional
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
from .dispatcher import MessageDispatcher from .dispatcher import MessageDispatcher
from ..gateway.auth import IamAuth
from ..gateway.config.receiver import ConfigReceiver from ..gateway.config.receiver import ConfigReceiver
from ..base import get_pubsub from ..base.pubsub import get_pubsub, add_pubsub_args
from ..base.logging import setup_logging, add_logging_args
logger = logging.getLogger("rev_gateway") logger = logging.getLogger("rev_gateway")
logger.setLevel(logging.INFO)
default_websocket = "ws://localhost:7650/out" default_websocket = "ws://localhost:7650/out"
default_timeout = 600
class ReverseGateway: class ReverseGateway:
def __init__(self, websocket_uri: str = None, max_workers: int = 10, def __init__(self, **config):
pulsar_host: str = None, pulsar_api_key: str = None, websocket_uri = config.get("websocket_uri")
pulsar_listener: str = None):
# Set default WebSocket URI with environment variable support
if websocket_uri is None: if websocket_uri is None:
websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket) websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)
# Parse and validate the WebSocket URI
parsed_uri = urlparse(websocket_uri) parsed_uri = urlparse(websocket_uri)
if parsed_uri.scheme not in ('ws', 'wss'): if parsed_uri.scheme not in ('ws', 'wss'):
raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}") raise ValueError(
f"WebSocket URI must use ws:// or wss:// scheme, "
f"got: {parsed_uri.scheme}"
)
if not parsed_uri.netloc: if not parsed_uri.netloc:
raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}") raise ValueError(
f"WebSocket URI must include hostname, "
# Store parsed components for debugging/logging f"got: {websocket_uri}"
)
self.websocket_uri = websocket_uri self.websocket_uri = websocket_uri
self.host = parsed_uri.hostname self.host = parsed_uri.hostname
self.port = parsed_uri.port self.port = parsed_uri.port
self.scheme = parsed_uri.scheme self.scheme = parsed_uri.scheme
self.path = parsed_uri.path or "/ws" self.path = parsed_uri.path or "/ws"
# Construct the full URL (in case path was missing)
if not parsed_uri.path: if not parsed_uri.path:
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws" self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
else: else:
self.url = websocket_uri self.url = websocket_uri
self.max_workers = max_workers self.max_workers = int(config.get("max_workers", 10))
self.timeout = int(config.get("timeout", default_timeout))
self.ws: Optional[ClientWebSocketResponse] = None self.ws: Optional[ClientWebSocketResponse] = None
self.session: Optional[ClientSession] = None self.session: Optional[ClientSession] = None
self.running = False self.running = False
self.reconnect_delay = 3.0 self.reconnect_delay = 3.0
# Pulsar configuration
self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
self.pulsar_listener = pulsar_listener
# Create backend using factory self.backend = get_pubsub(**config)
backend_params = {
'pulsar_host': self.pulsar_host,
'pulsar_api_key': self.pulsar_api_key,
'pulsar_listener': self.pulsar_listener,
}
self.backend = get_pubsub(**backend_params)
# Initialize config receiver self.auth = IamAuth(
self.config_receiver = ConfigReceiver(self.backend) backend=self.backend,
id=config.get("id", "rev-gateway"),
)
self.config_receiver = ConfigReceiver(
self.backend, auth=self.auth,
)
self.dispatcher = MessageDispatcher(
self.max_workers, self.config_receiver, self.backend,
auth=self.auth, timeout=self.timeout,
)
# Initialize dispatcher with config_receiver and backend - must be created after config_receiver
self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend)
async def connect(self) -> bool: async def connect(self) -> bool:
try: try:
if self.session is None: if self.session is None:
self.session = ClientSession() self.session = ClientSession()
logger.info(f"Connecting to {self.url}") logger.info(f"Connecting to {self.url}")
self.ws = await self.session.ws_connect(self.url) self.ws = await self.session.ws_connect(self.url)
logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}") logger.info(
f"WebSocket connection established to "
f"{self.host}:{self.port or 'default'}"
)
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to {self.url}: {e}") logger.error(f"Failed to connect to {self.url}: {e}")
return False return False
async def disconnect(self): async def disconnect(self):
if self.ws and not self.ws.closed: if self.ws and not self.ws.closed:
await self.ws.close() await self.ws.close()
@ -92,32 +102,31 @@ class ReverseGateway:
await self.session.close() await self.session.close()
self.ws = None self.ws = None
self.session = None self.session = None
async def send_message(self, message: dict): async def send_message(self, message: dict):
if self.ws and not self.ws.closed: if self.ws and not self.ws.closed:
try: try:
await self.ws.send_str(json.dumps(message)) await self.ws.send_str(json.dumps(message))
except Exception as e: except Exception as e:
logger.error(f"Failed to send message: {e}") logger.error(f"Failed to send message: {e}")
async def handle_message(self, message: str): async def handle_message(self, message: str):
try: try:
logger.debug(f"Received message: {message}") logger.debug(f"Received message: {message}")
msg_data = json.loads(message) msg_data = json.loads(message)
response = await self.dispatcher.handle_message(msg_data) await self.dispatcher.handle_message(
msg_data, self.send_message,
if response: )
await self.send_message(response)
except Exception as e: except Exception as e:
logger.error(f"Error handling message: {e}") logger.error(f"Error handling message: {e}")
async def listen(self): async def listen(self):
while self.running and self.ws and not self.ws.closed: while self.running and self.ws and not self.ws.closed:
try: try:
msg = await self.ws.receive() msg = await self.ws.receive()
if msg.type == WSMsgType.TEXT: if msg.type == WSMsgType.TEXT:
await self.handle_message(msg.data) await self.handle_message(msg.data)
elif msg.type == WSMsgType.BINARY: elif msg.type == WSMsgType.BINARY:
@ -125,31 +134,33 @@ class ReverseGateway:
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR): elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
logger.warning("WebSocket closed or error occurred") logger.warning("WebSocket closed or error occurred")
break break
except Exception as e: except Exception as e:
logger.error(f"Error in listen loop: {e}") logger.error(f"Error in listen loop: {e}")
break break
async def run(self): async def run(self):
self.running = True self.running = True
logger.info("Starting reverse gateway") logger.info("Starting reverse gateway")
# Start config receiver await self.auth.start()
logger.info("Starting config receiver")
await self.config_receiver.start() await self.config_receiver.start()
while self.running: while self.running:
try: try:
if await self.connect(): if await self.connect():
await self.listen() await self.listen()
else: else:
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds") logger.warning(
f"Connection failed, retrying in "
f"{self.reconnect_delay} seconds"
)
await self.disconnect() await self.disconnect()
if self.running: if self.running:
await asyncio.sleep(self.reconnect_delay) await asyncio.sleep(self.reconnect_delay)
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Shutdown requested") logger.info("Shutdown requested")
break break
@ -157,77 +168,69 @@ class ReverseGateway:
logger.error(f"Unexpected error: {e}") logger.error(f"Unexpected error: {e}")
if self.running: if self.running:
await asyncio.sleep(self.reconnect_delay) await asyncio.sleep(self.reconnect_delay)
await self.shutdown() await self.shutdown()
async def shutdown(self): async def shutdown(self):
logger.info("Shutting down reverse gateway") logger.info("Shutting down reverse gateway")
self.running = False self.running = False
await self.dispatcher.shutdown() await self.dispatcher.shutdown()
await self.disconnect() await self.disconnect()
# Close backend
if hasattr(self, 'backend'): if hasattr(self, 'backend'):
self.backend.close() self.backend.close()
def stop(self): def stop(self):
self.running = False self.running = False
def parse_args():
def run():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog="reverse-gateway", prog="reverse-gateway",
description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge" description=__doc__,
) )
parser.add_argument(
'--id',
default='rev-gateway',
help='Service identifier (default: rev-gateway)',
)
parser.add_argument( parser.add_argument(
'--websocket-uri', '--websocket-uri',
default=None, default=None,
help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)' help=f'WebSocket URI to connect to (default: {default_websocket})',
) )
parser.add_argument( parser.add_argument(
'--max-workers', '--max-workers',
type=int, type=int,
default=10, default=10,
help='Maximum concurrent message handlers (default: 10)' help='Maximum concurrent message handlers (default: 10)',
) )
parser.add_argument(
'-p', '--pulsar-host',
default=None,
help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
)
parser.add_argument(
'--pulsar-api-key',
default=None,
help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
)
parser.add_argument(
'--pulsar-listener',
default=None,
help='Pulsar listener name'
)
return parser.parse_args()
def run(): parser.add_argument(
args = parse_args() '--timeout',
type=int,
gateway = ReverseGateway( default=default_timeout,
websocket_uri=args.websocket_uri, help=f'Request timeout in seconds (default: {default_timeout})',
max_workers=args.max_workers,
pulsar_host=args.pulsar_host,
pulsar_api_key=args.pulsar_api_key,
pulsar_listener=args.pulsar_listener
) )
add_pubsub_args(parser)
add_logging_args(parser)
args = parser.parse_args()
args = vars(args)
setup_logging(args)
gateway = ReverseGateway(**args)
logger.info(f"Starting reverse gateway:") logger.info(f"Starting reverse gateway:")
logger.info(f" WebSocket URI: {gateway.url}") logger.info(f" WebSocket URI: {gateway.url}")
logger.info(f" Max workers: {args.max_workers}") logger.info(f" Max workers: {gateway.max_workers}")
logger.info(f" Pulsar host: {gateway.pulsar_host}")
try: try:
asyncio.run(gateway.run()) asyncio.run(gateway.run())
except KeyboardInterrupt: except KeyboardInterrupt: