mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-27 08:15:13 +02:00
Update rev-gateway for IAM integration (#940)
service.py: - Constructor takes **config (same pattern as api-gateway) instead of individual args - Creates IamAuth and calls await self.auth.start() before the message loop - Passes auth to both ConfigReceiver and MessageDispatcher - Uses add_pubsub_args / add_logging_args instead of hand-rolled Pulsar args - Passes timeout through dispatcher.py: - Accepts auth and timeout parameters - Passes both to DispatcherManager — fixes the missing auth argument that would have crashed on startup The remote end's requests now go through the same IAM authentication path as api-gateway. Token validation, workspace resolution, and permissions all work identically regardless of which direction initiated the connection. Fixed tests — the test now passes auth and timeout to MessageDispatcher and verifies they're forwarded to DispatcherManager. Update rev gateway dispatcher to align with IAM. A "token" parameter must be passed with each message. Fix websocket relay to align with rev-gateway changes, conforms to the api-gateway protocol.
This commit is contained in:
parent
4e3bd85abc
commit
e57f4669e1
7 changed files with 914 additions and 865 deletions
|
|
@ -3,208 +3,278 @@
|
|||
WebSocket Relay Test Harness
|
||||
|
||||
This script creates a relay server with two WebSocket endpoints:
|
||||
- /in - for test clients to connect to
|
||||
- /out - for reverse gateway to connect to
|
||||
- /in - for test clients to connect to (speaks api-gateway protocol)
|
||||
- /out - for reverse gateway to connect to (speaks rev-gateway protocol)
|
||||
|
||||
Messages are bidirectionally relayed between the two connections.
|
||||
Clients on /in authenticate with a first-frame auth message:
|
||||
{"type": "auth", "token": "..."}
|
||||
|
||||
The relay stores the token and injects it into each subsequent message
|
||||
before forwarding to /out. Responses from /out are forwarded back to
|
||||
the originating /in connection unchanged.
|
||||
|
||||
Usage:
|
||||
python websocket_relay.py [--port PORT] [--host HOST]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
from aiohttp import web, WSMsgType
|
||||
import weakref
|
||||
from typing import Optional, Set
|
||||
from typing import Dict, Optional
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger("websocket_relay")
|
||||
|
||||
|
||||
class InConnection:
|
||||
def __init__(self, ws, conn_id):
|
||||
self.ws = ws
|
||||
self.conn_id = conn_id
|
||||
self.token: Optional[str] = None
|
||||
self.authenticated = False
|
||||
|
||||
|
||||
class WebSocketRelay:
|
||||
"""WebSocket relay that forwards messages between 'in' and 'out' connections"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.in_connections: Set = weakref.WeakSet()
|
||||
self.out_connections: Set = weakref.WeakSet()
|
||||
|
||||
self.in_connections: Dict[str, InConnection] = {}
|
||||
self.out_connections: set = set()
|
||||
self._conn_counter = 0
|
||||
|
||||
def _next_conn_id(self):
|
||||
self._conn_counter += 1
|
||||
return f"conn-{self._conn_counter}"
|
||||
|
||||
async def handle_in_connection(self, request):
|
||||
"""Handle incoming connections on /in endpoint"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.in_connections.add(ws)
|
||||
logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
|
||||
conn_id = self._next_conn_id()
|
||||
conn = InConnection(ws, conn_id)
|
||||
self.in_connections[conn_id] = conn
|
||||
logger.info(
|
||||
f"New 'in' connection {conn_id}. "
|
||||
f"Total in: {len(self.in_connections)}, "
|
||||
f"out: {len(self.out_connections)}"
|
||||
)
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
data = msg.data
|
||||
logger.info(f"IN → OUT: {data}")
|
||||
await self._forward_to_out(data)
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
data = msg.data
|
||||
logger.info(f"IN → OUT: {len(data)} bytes (binary)")
|
||||
await self._forward_to_out(data, binary=True)
|
||||
await self._handle_in_message(conn, msg.data)
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket error on 'in' connection: {ws.exception()}")
|
||||
logger.error(
|
||||
f"WebSocket error on 'in' connection "
|
||||
f"{conn_id}: {ws.exception()}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in 'in' connection handler: {e}")
|
||||
logger.error(
|
||||
f"Error in 'in' connection {conn_id}: {e}"
|
||||
)
|
||||
finally:
|
||||
logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
del self.in_connections[conn_id]
|
||||
logger.info(
|
||||
f"'in' connection {conn_id} closed. "
|
||||
f"Remaining in: {len(self.in_connections)}, "
|
||||
f"out: {len(self.out_connections)}"
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
async def handle_out_connection(self, request):
|
||||
"""Handle outgoing connections on /out endpoint"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.out_connections.add(ws)
|
||||
logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
|
||||
async def _handle_in_message(self, conn, data):
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
data = msg.data
|
||||
logger.info(f"OUT → IN: {data}")
|
||||
await self._forward_to_in(data)
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
data = msg.data
|
||||
logger.info(f"OUT → IN: {len(data)} bytes (binary)")
|
||||
await self._forward_to_in(data, binary=True)
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket error on 'out' connection: {ws.exception()}")
|
||||
break
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in 'out' connection handler: {e}")
|
||||
finally:
|
||||
logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
return ws
|
||||
|
||||
async def _forward_to_out(self, data, binary=False):
|
||||
"""Forward message from 'in' to all 'out' connections"""
|
||||
if not self.out_connections:
|
||||
logger.warning("No 'out' connections available to forward message")
|
||||
message = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"{conn.conn_id}: received non-JSON message"
|
||||
)
|
||||
return
|
||||
|
||||
closed_connections = []
|
||||
|
||||
if isinstance(message, dict) and message.get("type") == "auth":
|
||||
conn.token = message.get("token", "")
|
||||
conn.authenticated = True
|
||||
logger.info(f"{conn.conn_id}: authenticated")
|
||||
await conn.ws.send_json({
|
||||
"type": "auth-ok",
|
||||
"workspace": "relayed",
|
||||
})
|
||||
return
|
||||
|
||||
if not conn.authenticated:
|
||||
await conn.ws.send_json({
|
||||
"error": {
|
||||
"message": "auth required",
|
||||
"type": "auth-required",
|
||||
},
|
||||
"complete": True,
|
||||
})
|
||||
return
|
||||
|
||||
message["token"] = conn.token
|
||||
message["_relay_conn"] = conn.conn_id
|
||||
|
||||
forwarded = json.dumps(message)
|
||||
logger.info(f"IN {conn.conn_id} → OUT: {forwarded}")
|
||||
await self._forward_to_out(forwarded)
|
||||
|
||||
async def handle_out_connection(self, request):
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.out_connections.add(ws)
|
||||
logger.info(
|
||||
f"New 'out' connection. "
|
||||
f"Total in: {len(self.in_connections)}, "
|
||||
f"out: {len(self.out_connections)}"
|
||||
)
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
await self._handle_out_message(msg.data)
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
logger.error(
|
||||
f"WebSocket error on 'out' connection: "
|
||||
f"{ws.exception()}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in 'out' connection: {e}")
|
||||
finally:
|
||||
self.out_connections.discard(ws)
|
||||
logger.info(
|
||||
f"'out' connection closed. "
|
||||
f"Remaining in: {len(self.in_connections)}, "
|
||||
f"out: {len(self.out_connections)}"
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
async def _handle_out_message(self, data):
|
||||
try:
|
||||
message = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("OUT: received non-JSON message")
|
||||
return
|
||||
|
||||
conn_id = message.pop("_relay_conn", None)
|
||||
|
||||
forwarded = json.dumps(message)
|
||||
logger.info(f"OUT → IN {conn_id or 'broadcast'}: {forwarded}")
|
||||
|
||||
if conn_id and conn_id in self.in_connections:
|
||||
conn = self.in_connections[conn_id]
|
||||
try:
|
||||
if not conn.ws.closed:
|
||||
await conn.ws.send_str(forwarded)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error forwarding to 'in' {conn_id}: {e}"
|
||||
)
|
||||
else:
|
||||
await self._broadcast_to_in(forwarded)
|
||||
|
||||
async def _broadcast_to_in(self, data):
|
||||
closed = []
|
||||
for conn_id, conn in list(self.in_connections.items()):
|
||||
try:
|
||||
if conn.ws.closed:
|
||||
closed.append(conn_id)
|
||||
continue
|
||||
await conn.ws.send_str(data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error broadcasting to 'in' {conn_id}: {e}"
|
||||
)
|
||||
closed.append(conn_id)
|
||||
for conn_id in closed:
|
||||
self.in_connections.pop(conn_id, None)
|
||||
|
||||
async def _forward_to_out(self, data):
|
||||
closed = []
|
||||
for ws in list(self.out_connections):
|
||||
try:
|
||||
if ws.closed:
|
||||
closed_connections.append(ws)
|
||||
closed.append(ws)
|
||||
continue
|
||||
|
||||
if binary:
|
||||
await ws.send_bytes(data)
|
||||
else:
|
||||
await ws.send_str(data)
|
||||
await ws.send_str(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding to 'out' connection: {e}")
|
||||
closed_connections.append(ws)
|
||||
|
||||
# Clean up closed connections
|
||||
for ws in closed_connections:
|
||||
if ws in self.out_connections:
|
||||
self.out_connections.discard(ws)
|
||||
|
||||
async def _forward_to_in(self, data, binary=False):
|
||||
"""Forward message from 'out' to all 'in' connections"""
|
||||
if not self.in_connections:
|
||||
logger.warning("No 'in' connections available to forward message")
|
||||
return
|
||||
|
||||
closed_connections = []
|
||||
for ws in list(self.in_connections):
|
||||
try:
|
||||
if ws.closed:
|
||||
closed_connections.append(ws)
|
||||
continue
|
||||
|
||||
if binary:
|
||||
await ws.send_bytes(data)
|
||||
else:
|
||||
await ws.send_str(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding to 'in' connection: {e}")
|
||||
closed_connections.append(ws)
|
||||
|
||||
# Clean up closed connections
|
||||
for ws in closed_connections:
|
||||
if ws in self.in_connections:
|
||||
self.in_connections.discard(ws)
|
||||
logger.error(f"Error forwarding to 'out': {e}")
|
||||
closed.append(ws)
|
||||
for ws in closed:
|
||||
self.out_connections.discard(ws)
|
||||
|
||||
|
||||
async def create_app(relay):
|
||||
"""Create the web application with routes"""
|
||||
app = web.Application()
|
||||
|
||||
# Add routes
|
||||
app.router.add_get('/in', relay.handle_in_connection)
|
||||
|
||||
app.router.add_get('/in/api/v1/socket', relay.handle_in_connection)
|
||||
app.router.add_get('/out', relay.handle_out_connection)
|
||||
|
||||
# Add a simple status endpoint
|
||||
|
||||
async def status(request):
|
||||
status_info = {
|
||||
return web.json_response({
|
||||
'in_connections': len(relay.in_connections),
|
||||
'out_connections': len(relay.out_connections),
|
||||
'status': 'running'
|
||||
}
|
||||
return web.json_response(status_info)
|
||||
|
||||
'status': 'running',
|
||||
})
|
||||
|
||||
app.router.add_get('/status', status)
|
||||
app.router.add_get('/', status) # Root also shows status
|
||||
|
||||
app.router.add_get('/', status)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="WebSocket Relay Test Harness"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--host',
|
||||
'--host',
|
||||
default='localhost',
|
||||
help='Host to bind to (default: localhost)'
|
||||
help='Host to bind to (default: localhost)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--port',
|
||||
type=int,
|
||||
'--port',
|
||||
type=int,
|
||||
default=8080,
|
||||
help='Port to bind to (default: 8080)'
|
||||
help='Port to bind to (default: 8080)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verbose', '-v',
|
||||
action='store_true',
|
||||
help='Enable verbose logging'
|
||||
help='Enable verbose logging',
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if args.verbose:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
relay = WebSocketRelay()
|
||||
|
||||
|
||||
print(f"Starting WebSocket Relay on {args.host}:{args.port}")
|
||||
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in")
|
||||
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket")
|
||||
print(f" 'out' endpoint: ws://{args.host}:{args.port}/out")
|
||||
print(f" Status: http://{args.host}:{args.port}/status")
|
||||
print()
|
||||
print("Usage:")
|
||||
print(f" Test client connects to: ws://{args.host}:{args.port}/in")
|
||||
print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out")
|
||||
|
||||
print("Client protocol (same as api-gateway):")
|
||||
print(' 1. Connect to /in/api/v1/socket')
|
||||
print(' 2. Send: {"type": "auth", "token": "tg_..."}')
|
||||
print(' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}')
|
||||
print(' 4. Send requests as normal')
|
||||
|
||||
web.run_app(create_app(relay), host=args.host, port=args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -25,16 +25,17 @@ class TestSemaphoreEnforcement:
|
|||
max_concurrent = 0
|
||||
processing_event = asyncio.Event()
|
||||
|
||||
async def slow_process(message):
|
||||
async def slow_process(message, sender):
|
||||
nonlocal concurrent_count, max_concurrent
|
||||
concurrent_count += 1
|
||||
max_concurrent = max(max_concurrent, concurrent_count)
|
||||
await asyncio.sleep(0.05)
|
||||
concurrent_count -= 1
|
||||
return {"id": message.get("id"), "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = slow_process
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
# Launch more tasks than max_workers
|
||||
messages = [
|
||||
{"id": f"msg-{i}", "service": "test", "request": {}}
|
||||
|
|
@ -42,7 +43,7 @@ class TestSemaphoreEnforcement:
|
|||
]
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(dispatcher.handle_message(m))
|
||||
asyncio.create_task(dispatcher.handle_message(m, sender))
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
|
@ -66,17 +67,17 @@ class TestSemaphoreEnforcement:
|
|||
|
||||
original_process = dispatcher._process_message
|
||||
|
||||
async def tracking_process(message):
|
||||
async def tracking_process(message, sender):
|
||||
nonlocal task_was_tracked
|
||||
# During processing, our task should be in active_tasks
|
||||
if len(dispatcher.active_tasks) > 0:
|
||||
task_was_tracked = True
|
||||
return {"id": message.get("id"), "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = tracking_process
|
||||
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test", "service": "test", "request": {}}
|
||||
{"id": "test", "service": "test", "request": {}},
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
assert task_was_tracked
|
||||
|
|
@ -88,7 +89,7 @@ class TestSemaphoreEnforcement:
|
|||
"""Semaphore should be released even if processing raises."""
|
||||
dispatcher = MessageDispatcher(max_workers=2)
|
||||
|
||||
async def failing_process(message):
|
||||
async def failing_process(message, sender):
|
||||
raise RuntimeError("process failed")
|
||||
|
||||
dispatcher._process_message = failing_process
|
||||
|
|
@ -96,7 +97,8 @@ class TestSemaphoreEnforcement:
|
|||
# Should not deadlock — semaphore must be released on error
|
||||
with pytest.raises(RuntimeError):
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test", "service": "test", "request": {}}
|
||||
{"id": "test", "service": "test", "request": {}},
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
# Semaphore should be back at max
|
||||
|
|
@ -109,17 +111,18 @@ class TestSemaphoreEnforcement:
|
|||
|
||||
order = []
|
||||
|
||||
async def ordered_process(message):
|
||||
async def ordered_process(message, sender):
|
||||
msg_id = message["id"]
|
||||
order.append(f"start-{msg_id}")
|
||||
await asyncio.sleep(0.02)
|
||||
order.append(f"end-{msg_id}")
|
||||
return {"id": msg_id, "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = ordered_process
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
|
||||
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
|
||||
tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# With semaphore=1, each message should complete before next starts
|
||||
|
|
|
|||
0
tests/unit/test_rev_gateway/__init__.py
Normal file
0
tests/unit/test_rev_gateway/__init__.py
Normal file
|
|
@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher
|
|||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch, ANY
|
||||
|
||||
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher
|
||||
|
||||
|
||||
class TestWebSocketResponder:
|
||||
"""Test cases for WebSocketResponder class"""
|
||||
|
||||
def test_websocket_responder_initialization(self):
|
||||
"""Test WebSocketResponder initialization"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
assert responder.response is None
|
||||
assert responder.completed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_send_method(self):
|
||||
"""Test WebSocketResponder send method"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"data": "test response"}
|
||||
|
||||
# Call send method
|
||||
await responder.send(test_response)
|
||||
|
||||
# Verify response was stored
|
||||
assert responder.response == test_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_call_method(self):
|
||||
"""Test WebSocketResponder __call__ method"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"result": "success"}
|
||||
test_completed = True
|
||||
|
||||
# Call the responder
|
||||
await responder(test_response, test_completed)
|
||||
|
||||
# Verify response and completed status were set
|
||||
assert responder.response == test_response
|
||||
assert responder.completed == test_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_call_method_with_false_completion(self):
|
||||
"""Test WebSocketResponder __call__ method with incomplete response"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"partial": "data"}
|
||||
test_completed = False
|
||||
|
||||
# Call the responder
|
||||
await responder(test_response, test_completed)
|
||||
|
||||
# Verify response was set and completed is True (since send() always sets completed=True)
|
||||
assert responder.response == test_response
|
||||
assert responder.completed is True
|
||||
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
|
||||
|
||||
|
||||
class TestMessageDispatcher:
|
||||
"""Test cases for MessageDispatcher class"""
|
||||
|
||||
def test_message_dispatcher_initialization_with_defaults(self):
|
||||
"""Test MessageDispatcher initialization with default parameters"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
|
||||
assert dispatcher.max_workers == 10
|
||||
assert dispatcher.semaphore._value == 10
|
||||
assert dispatcher.active_tasks == set()
|
||||
assert dispatcher.backend is None
|
||||
assert dispatcher.auth is None
|
||||
assert dispatcher.dispatcher_manager is None
|
||||
assert len(dispatcher.service_mapping) > 0
|
||||
|
||||
def test_message_dispatcher_initialization_with_custom_workers(self):
|
||||
"""Test MessageDispatcher initialization with custom max_workers"""
|
||||
dispatcher = MessageDispatcher(max_workers=5)
|
||||
|
||||
|
||||
assert dispatcher.max_workers == 5
|
||||
assert dispatcher.semaphore._value == 5
|
||||
|
||||
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
|
||||
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
|
||||
"""Test MessageDispatcher initialization with pulsar_client and config_receiver"""
|
||||
def test_message_dispatcher_initialization_with_backend(
|
||||
self, mock_dispatcher_manager,
|
||||
):
|
||||
mock_backend = MagicMock()
|
||||
mock_config_receiver = MagicMock()
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher_instance = MagicMock()
|
||||
mock_dispatcher_manager.return_value = mock_dispatcher_instance
|
||||
|
||||
|
||||
dispatcher = MessageDispatcher(
|
||||
max_workers=8,
|
||||
config_receiver=mock_config_receiver,
|
||||
backend=mock_backend
|
||||
backend=mock_backend,
|
||||
auth=mock_auth,
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
|
||||
assert dispatcher.max_workers == 8
|
||||
assert dispatcher.backend == mock_backend
|
||||
assert dispatcher.auth == mock_auth
|
||||
assert dispatcher.dispatcher_manager == mock_dispatcher_instance
|
||||
mock_dispatcher_manager.assert_called_once_with(
|
||||
mock_backend, mock_config_receiver, prefix="rev-gateway"
|
||||
mock_backend, mock_config_receiver,
|
||||
auth=mock_auth, prefix="rev-gateway", timeout=300,
|
||||
)
|
||||
|
||||
def test_message_dispatcher_service_mapping(self):
|
||||
"""Test MessageDispatcher service mapping contains expected services"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
|
||||
expected_services = [
|
||||
"text-completion", "graph-rag", "agent", "embeddings",
|
||||
"graph-embeddings", "triples", "document-load", "text-load",
|
||||
"flow", "knowledge", "config", "librarian", "document-rag"
|
||||
"flow", "knowledge", "config", "librarian", "document-rag",
|
||||
]
|
||||
|
||||
|
||||
for service in expected_services:
|
||||
assert service in dispatcher.service_mapping
|
||||
|
||||
# Test specific mappings
|
||||
assert dispatcher.service_mapping["text-completion"] == "text-completion"
|
||||
|
||||
assert dispatcher.service_mapping["document-load"] == "document"
|
||||
assert dispatcher.service_mapping["text-load"] == "text-document"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self):
|
||||
"""Test MessageDispatcher handle_message without dispatcher manager"""
|
||||
async def test_handle_message_without_dispatcher_manager(self):
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
test_message = {
|
||||
"id": "test-123",
|
||||
"service": "test-service",
|
||||
"request": {"data": "test"}
|
||||
}
|
||||
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-123"
|
||||
assert "error" in result["response"]
|
||||
assert "DispatcherManager not available" in result["response"]["error"]
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
return_value=MagicMock(workspace="default")
|
||||
)
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test-1", "service": "test", "request": {}},
|
||||
sender,
|
||||
)
|
||||
|
||||
sender.assert_called_once()
|
||||
sent = sender.call_args[0][0]
|
||||
assert sent["id"] == "test-1"
|
||||
assert sent["error"]["message"] == "DispatcherManager not available"
|
||||
assert sent["error"]["type"] == "error"
|
||||
assert sent["complete"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_with_exception(self):
|
||||
"""Test MessageDispatcher handle_message with exception during processing"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
async def test_handle_message_auth_failure(self):
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-456",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "test"}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-456"
|
||||
assert "error" in result["response"]
|
||||
assert "Test error" in result["response"]["error"]
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
side_effect=Exception("auth failure")
|
||||
)
|
||||
dispatcher.dispatcher_manager = MagicMock()
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test-2", "token": "bad", "service": "test", "request": {}},
|
||||
sender,
|
||||
)
|
||||
|
||||
sender.assert_called_once()
|
||||
sent = sender.call_args[0][0]
|
||||
assert sent["id"] == "test-2"
|
||||
assert "auth failure" in sent["error"]["message"]
|
||||
assert sent["complete"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_global_service(self):
|
||||
"""Test MessageDispatcher handle_message with global service"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_global_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = True
|
||||
mock_responder.response = {"result": "success"}
|
||||
|
||||
async def test_handle_message_global_service(self):
|
||||
mock_dm = MagicMock()
|
||||
mock_dm.invoke_global_service = AsyncMock()
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-789",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "hello"}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-789"
|
||||
assert result["response"] == {"result": "success"}
|
||||
mock_dispatcher_manager.invoke_global_service.assert_called_once()
|
||||
dispatcher.dispatcher_manager = mock_dm
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
return_value=MagicMock(workspace="ws1")
|
||||
)
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'trustgraph.gateway.dispatch.manager.global_dispatchers',
|
||||
{"text-completion": True},
|
||||
):
|
||||
await dispatcher.handle_message(
|
||||
{
|
||||
"id": "test-3",
|
||||
"token": "tg_key",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "hello"},
|
||||
},
|
||||
sender,
|
||||
)
|
||||
|
||||
mock_dm.invoke_global_service.assert_called_once()
|
||||
args, kwargs = mock_dm.invoke_global_service.call_args
|
||||
assert args[0] == {"prompt": "hello"}
|
||||
assert args[2] == "text-completion"
|
||||
assert kwargs["workspace"] == "ws1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_flow_service(self):
|
||||
"""Test MessageDispatcher handle_message with flow service"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = True
|
||||
mock_responder.response = {"data": "flow_result"}
|
||||
|
||||
async def test_handle_message_flow_service(self):
|
||||
mock_dm = MagicMock()
|
||||
mock_dm.invoke_flow_service = AsyncMock()
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-flow-123",
|
||||
"service": "document-rag",
|
||||
"request": {"query": "test"},
|
||||
"flow": "custom-flow"
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-flow-123"
|
||||
assert result["response"] == {"data": "flow_result"}
|
||||
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with(
|
||||
{"query": "test"}, mock_responder, "custom-flow", "document-rag"
|
||||
dispatcher.dispatcher_manager = mock_dm
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
return_value=MagicMock(workspace="ws2")
|
||||
)
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
|
||||
):
|
||||
await dispatcher.handle_message(
|
||||
{
|
||||
"id": "test-4",
|
||||
"token": "tg_key",
|
||||
"service": "document-rag",
|
||||
"request": {"query": "test"},
|
||||
"flow": "my-flow",
|
||||
},
|
||||
sender,
|
||||
)
|
||||
|
||||
mock_dm.invoke_flow_service.assert_called_once_with(
|
||||
{"query": "test"}, ANY, "ws2", "my-flow", "document-rag",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_incomplete_response(self):
|
||||
"""Test MessageDispatcher handle_message with incomplete response"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = False
|
||||
mock_responder.response = None
|
||||
|
||||
async def test_handle_message_responder_sends_frames(self):
|
||||
mock_dm = MagicMock()
|
||||
|
||||
async def fake_invoke(data, responder, svc, workspace=None):
|
||||
await responder({"partial": 1}, False)
|
||||
await responder({"partial": 2}, True)
|
||||
|
||||
mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke)
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-incomplete",
|
||||
"service": "agent",
|
||||
"request": {"input": "test"}
|
||||
dispatcher.dispatcher_manager = mock_dm
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
return_value=MagicMock(workspace="ws1")
|
||||
)
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'trustgraph.gateway.dispatch.manager.global_dispatchers',
|
||||
{"text-completion": True},
|
||||
):
|
||||
await dispatcher.handle_message(
|
||||
{
|
||||
"id": "test-5",
|
||||
"token": "tg_key",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "hi"},
|
||||
},
|
||||
sender,
|
||||
)
|
||||
|
||||
assert sender.call_count == 2
|
||||
first = sender.call_args_list[0][0][0]
|
||||
second = sender.call_args_list[1][0][0]
|
||||
|
||||
assert first == {
|
||||
"id": "test-5", "response": {"partial": 1}, "complete": False,
|
||||
}
|
||||
assert second == {
|
||||
"id": "test-5", "response": {"partial": 2}, "complete": True,
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-incomplete"
|
||||
assert result["response"] == {"error": "No response received"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_shutdown(self):
|
||||
"""Test MessageDispatcher shutdown method"""
|
||||
import asyncio
|
||||
|
||||
async def test_handle_message_workspace_from_identity(self):
|
||||
mock_dm = MagicMock()
|
||||
mock_dm.invoke_flow_service = AsyncMock()
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
# Create actual async tasks
|
||||
dispatcher.dispatcher_manager = mock_dm
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
return_value=MagicMock(workspace="derived-ws")
|
||||
)
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
|
||||
):
|
||||
await dispatcher.handle_message(
|
||||
{
|
||||
"id": "test-6",
|
||||
"token": "tg_key",
|
||||
"service": "agent",
|
||||
"request": {"question": "test"},
|
||||
"flow": "default",
|
||||
},
|
||||
sender,
|
||||
)
|
||||
|
||||
args = mock_dm.invoke_flow_service.call_args[0]
|
||||
assert args[2] == "derived-ws"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown(self):
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
async def dummy_task():
|
||||
await asyncio.sleep(0.01)
|
||||
return "done"
|
||||
|
||||
|
||||
task1 = asyncio.create_task(dummy_task())
|
||||
task2 = asyncio.create_task(dummy_task())
|
||||
dispatcher.active_tasks = {task1, task2}
|
||||
|
||||
# Call shutdown
|
||||
|
||||
await dispatcher.shutdown()
|
||||
|
||||
# Verify tasks were completed
|
||||
|
||||
assert task1.done()
|
||||
assert task2.done()
|
||||
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_shutdown_with_no_tasks(self):
|
||||
"""Test MessageDispatcher shutdown with no active tasks"""
|
||||
async def test_shutdown_with_no_tasks(self):
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
# Call shutdown with no active tasks
|
||||
|
||||
await dispatcher.shutdown()
|
||||
|
||||
# Should complete without error
|
||||
assert dispatcher.active_tasks == set()
|
||||
|
||||
assert dispatcher.active_tasks == set()
|
||||
|
|
|
|||
|
|
@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock
|
|||
from aiohttp import WSMsgType, ClientWebSocketResponse
|
||||
import json
|
||||
|
||||
from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run
|
||||
from trustgraph.rev_gateway.service import ReverseGateway, run
|
||||
|
||||
|
||||
MOCK_PATCHES = [
|
||||
'trustgraph.rev_gateway.service.IamAuth',
|
||||
'trustgraph.rev_gateway.service.ConfigReceiver',
|
||||
'trustgraph.rev_gateway.service.MessageDispatcher',
|
||||
'trustgraph.rev_gateway.service.get_pubsub',
|
||||
]
|
||||
|
||||
|
||||
def make_gateway(**overrides):
|
||||
config = {"websocket_uri": "ws://localhost:7650/out"}
|
||||
config.update(overrides)
|
||||
return ReverseGateway(**config)
|
||||
|
||||
|
||||
class TestReverseGateway:
|
||||
"""Test cases for ReverseGateway class"""
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with default parameters"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_initialization_defaults(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
assert gateway.websocket_uri == "ws://localhost:7650/out"
|
||||
assert gateway.host == "localhost"
|
||||
assert gateway.port == 7650
|
||||
|
|
@ -33,25 +49,22 @@ class TestReverseGateway:
|
|||
assert gateway.max_workers == 10
|
||||
assert gateway.running is False
|
||||
assert gateway.reconnect_delay == 3.0
|
||||
assert gateway.pulsar_host == "pulsar://pulsar:6650"
|
||||
assert gateway.pulsar_api_key is None
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with custom parameters"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway(
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_initialization_custom_params(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway(
|
||||
websocket_uri="wss://example.com:8080/websocket",
|
||||
max_workers=20,
|
||||
pulsar_host="pulsar://custom:6650",
|
||||
pulsar_api_key="test-key",
|
||||
pulsar_listener="test-listener"
|
||||
)
|
||||
|
||||
|
||||
assert gateway.websocket_uri == "wss://example.com:8080/websocket"
|
||||
assert gateway.host == "example.com"
|
||||
assert gateway.port == 8080
|
||||
|
|
@ -59,340 +72,360 @@ class TestReverseGateway:
|
|||
assert gateway.path == "/websocket"
|
||||
assert gateway.url == "wss://example.com:8080/websocket"
|
||||
assert gateway.max_workers == 20
|
||||
assert gateway.pulsar_host == "pulsar://custom:6650"
|
||||
assert gateway.pulsar_api_key == "test-key"
|
||||
assert gateway.pulsar_listener == "test-listener"
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with WebSocket URI missing path"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway(websocket_uri="ws://example.com")
|
||||
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_initialization_with_missing_path(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway(websocket_uri="ws://example.com")
|
||||
|
||||
assert gateway.path == "/ws"
|
||||
assert gateway.url == "ws://example.com/ws"
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with invalid WebSocket scheme"""
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_initialization_invalid_scheme(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
|
||||
ReverseGateway(websocket_uri="http://example.com")
|
||||
make_gateway(websocket_uri="http://example.com")
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with missing hostname"""
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_initialization_missing_hostname(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
|
||||
ReverseGateway(websocket_uri="ws://")
|
||||
make_gateway(websocket_uri="ws://")
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway creates backend with authentication"""
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_iam_auth_created(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway(
|
||||
pulsar_api_key="test-key",
|
||||
pulsar_listener="test-listener"
|
||||
gateway = make_gateway(id="test-rev-gw")
|
||||
|
||||
mock_iam_auth.assert_called_once_with(
|
||||
backend=mock_backend,
|
||||
id="test-rev-gw",
|
||||
)
|
||||
|
||||
# Verify get_pubsub was called with the correct parameters
|
||||
mock_get_pubsub.assert_called_once_with(
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
pulsar_api_key="test-key",
|
||||
pulsar_listener="test-listener"
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_config_receiver_gets_auth(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_iam_auth.return_value = mock_auth_instance
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
mock_config_receiver.assert_called_once_with(
|
||||
mock_backend, auth=mock_auth_instance,
|
||||
)
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@patch('trustgraph.rev_gateway.service.ClientSession')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway successful connection"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
async def test_reverse_gateway_connect_success(
|
||||
self, mock_session_class, mock_get_pubsub,
|
||||
mock_dispatcher, mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_session.ws_connect.return_value = mock_ws
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
result = await gateway.connect()
|
||||
|
||||
|
||||
assert result is True
|
||||
assert gateway.session == mock_session
|
||||
assert gateway.ws == mock_ws
|
||||
mock_session.ws_connect.assert_called_once_with(gateway.url)
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@patch('trustgraph.rev_gateway.service.ClientSession')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway connection failure"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
async def test_reverse_gateway_connect_failure(
|
||||
self, mock_session_class, mock_get_pubsub,
|
||||
mock_dispatcher, mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.ws_connect.side_effect = Exception("Connection failed")
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
result = await gateway.connect()
|
||||
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway disconnect"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock websocket and session
|
||||
async def test_reverse_gateway_disconnect(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
mock_session = AsyncMock()
|
||||
mock_session.closed = False
|
||||
|
||||
|
||||
gateway.ws = mock_ws
|
||||
gateway.session = mock_session
|
||||
|
||||
|
||||
await gateway.disconnect()
|
||||
|
||||
|
||||
mock_ws.close.assert_called_once()
|
||||
mock_session.close.assert_called_once()
|
||||
assert gateway.ws is None
|
||||
assert gateway.session is None
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway send message"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock websocket
|
||||
async def test_reverse_gateway_send_message(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
|
||||
test_message = {"id": "test", "data": "hello"}
|
||||
|
||||
|
||||
await gateway.send_message(test_message)
|
||||
|
||||
|
||||
mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway send message with closed connection"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock closed websocket
|
||||
async def test_reverse_gateway_send_message_closed_connection(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = True
|
||||
gateway.ws = mock_ws
|
||||
|
||||
|
||||
test_message = {"id": "test", "data": "hello"}
|
||||
|
||||
|
||||
await gateway.send_message(test_message)
|
||||
|
||||
# Should not call send_str on closed connection
|
||||
|
||||
mock_ws.send_str.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway handle message"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock send_message
|
||||
gateway.send_message = AsyncMock()
|
||||
|
||||
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
|
||||
|
||||
await gateway.handle_message(test_message)
|
||||
|
||||
mock_dispatcher_instance.handle_message.assert_called_once_with({
|
||||
"id": "test",
|
||||
"service": "test-service",
|
||||
"request": {"data": "test"}
|
||||
})
|
||||
gateway.send_message.assert_called_once_with({"response": "success"})
|
||||
async def test_reverse_gateway_handle_message(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway handle message with invalid JSON"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock send_message
|
||||
gateway.send_message = AsyncMock()
|
||||
|
||||
test_message = 'invalid json'
|
||||
|
||||
# Should not raise exception
|
||||
|
||||
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
|
||||
|
||||
await gateway.handle_message(test_message)
|
||||
|
||||
# Should not call send_message due to error
|
||||
|
||||
mock_dispatcher_instance.handle_message.assert_called_once_with(
|
||||
{
|
||||
"id": "test",
|
||||
"service": "test-service",
|
||||
"request": {"data": "test"},
|
||||
},
|
||||
gateway.send_message,
|
||||
)
|
||||
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message_invalid_json(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
gateway.send_message = AsyncMock()
|
||||
|
||||
await gateway.handle_message('invalid json')
|
||||
|
||||
gateway.send_message.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with text message"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
async def test_reverse_gateway_listen_text_message(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.TEXT
|
||||
mock_msg.data = '{"test": "message"}'
|
||||
|
||||
# Mock receive to return message once, then raise exception to stop loop
|
||||
|
||||
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
||||
|
||||
# listen() catches exceptions and breaks, so no exception should be raised
|
||||
|
||||
await gateway.listen()
|
||||
|
||||
|
||||
gateway.handle_message.assert_called_once_with('{"test": "message"}')
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with binary message"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
async def test_reverse_gateway_listen_binary_message(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.BINARY
|
||||
mock_msg.data = b'{"test": "binary"}'
|
||||
|
||||
# Mock receive to return message once, then raise exception to stop loop
|
||||
|
||||
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
||||
|
||||
# listen() catches exceptions and breaks, so no exception should be raised
|
||||
|
||||
await gateway.listen()
|
||||
|
||||
|
||||
gateway.handle_message.assert_called_once_with('{"test": "binary"}')
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with close message"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
async def test_reverse_gateway_listen_close_message(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.CLOSE
|
||||
|
||||
# Mock receive to return close message
|
||||
|
||||
mock_ws.receive.return_value = mock_msg
|
||||
|
||||
|
||||
await gateway.listen()
|
||||
|
||||
# Should not call handle_message for close message
|
||||
|
||||
gateway.handle_message.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway shutdown"""
|
||||
async def test_reverse_gateway_shutdown(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway = make_gateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock disconnect
|
||||
gateway.disconnect = AsyncMock()
|
||||
|
||||
await gateway.shutdown()
|
||||
|
|
@ -402,46 +435,50 @@ class TestReverseGateway:
|
|||
gateway.disconnect.assert_called_once()
|
||||
mock_backend.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway stop"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_stop(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.running = True
|
||||
|
||||
|
||||
gateway.stop()
|
||||
|
||||
|
||||
assert gateway.running is False
|
||||
|
||||
|
||||
class TestReverseGatewayRun:
|
||||
"""Test cases for ReverseGateway run method"""
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway run method with successful connect/listen cycle"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
async def test_reverse_gateway_run_successful_cycle(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
mock_auth_instance = AsyncMock()
|
||||
mock_iam_auth.return_value = mock_auth_instance
|
||||
|
||||
mock_config_receiver_instance = AsyncMock()
|
||||
mock_config_receiver.return_value = mock_config_receiver_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock methods
|
||||
gateway.connect = AsyncMock(return_value=True)
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
gateway.listen = AsyncMock()
|
||||
gateway.disconnect = AsyncMock()
|
||||
gateway.shutdown = AsyncMock()
|
||||
|
||||
# Stop after one iteration
|
||||
|
||||
call_count = 0
|
||||
async def mock_connect():
|
||||
nonlocal call_count
|
||||
|
|
@ -451,91 +488,13 @@ class TestReverseGatewayRun:
|
|||
else:
|
||||
gateway.running = False
|
||||
return False
|
||||
|
||||
|
||||
gateway.connect = mock_connect
|
||||
|
||||
|
||||
await gateway.run()
|
||||
|
||||
|
||||
mock_auth_instance.start.assert_called_once()
|
||||
mock_config_receiver_instance.start.assert_called_once()
|
||||
gateway.listen.assert_called_once()
|
||||
# disconnect is called twice: once in the main loop, once in shutdown
|
||||
assert gateway.disconnect.call_count == 2
|
||||
gateway.shutdown.assert_called_once()
|
||||
|
||||
|
||||
class TestReverseGatewayArgs:
|
||||
"""Test cases for argument parsing and run function"""
|
||||
|
||||
def test_parse_args_defaults(self):
|
||||
"""Test parse_args with default values"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = ['reverse-gateway']
|
||||
|
||||
try:
|
||||
args = parse_args()
|
||||
|
||||
assert args.websocket_uri is None
|
||||
assert args.max_workers == 10
|
||||
assert args.pulsar_host is None
|
||||
assert args.pulsar_api_key is None
|
||||
assert args.pulsar_listener is None
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
def test_parse_args_custom_values(self):
|
||||
"""Test parse_args with custom values"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = [
|
||||
'reverse-gateway',
|
||||
'--websocket-uri', 'ws://custom:8080/ws',
|
||||
'--max-workers', '20',
|
||||
'--pulsar-host', 'pulsar://custom:6650',
|
||||
'--pulsar-api-key', 'test-key',
|
||||
'--pulsar-listener', 'test-listener'
|
||||
]
|
||||
|
||||
try:
|
||||
args = parse_args()
|
||||
|
||||
assert args.websocket_uri == 'ws://custom:8080/ws'
|
||||
assert args.max_workers == 20
|
||||
assert args.pulsar_host == 'pulsar://custom:6650'
|
||||
assert args.pulsar_api_key == 'test-key'
|
||||
assert args.pulsar_listener == 'test-listener'
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ReverseGateway')
|
||||
@patch('asyncio.run')
|
||||
def test_run_function(self, mock_asyncio_run, mock_gateway_class):
|
||||
"""Test run function"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = ['reverse-gateway', '--max-workers', '15']
|
||||
|
||||
try:
|
||||
mock_gateway_instance = MagicMock()
|
||||
mock_gateway_instance.url = "ws://localhost:7650/out"
|
||||
mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650"
|
||||
mock_gateway_class.return_value = mock_gateway_instance
|
||||
|
||||
run()
|
||||
|
||||
mock_gateway_class.assert_called_once_with(
|
||||
websocket_uri=None,
|
||||
max_workers=15,
|
||||
pulsar_host=None,
|
||||
pulsar_api_key=None,
|
||||
pulsar_listener=None
|
||||
)
|
||||
mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run())
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
|
@ -1,130 +1,140 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
from trustgraph.messaging import TranslatorRegistry
|
||||
from typing import Dict, Any, Optional, Callable, Awaitable
|
||||
from ..gateway.dispatch.manager import DispatcherManager
|
||||
|
||||
logger = logging.getLogger("dispatcher")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class WebSocketResponder:
|
||||
"""Simple responder that captures response for websocket return"""
|
||||
def __init__(self):
|
||||
self.response = None
|
||||
self.completed = False
|
||||
|
||||
async def send(self, data):
|
||||
"""Capture the response data"""
|
||||
self.response = data
|
||||
self.completed = True
|
||||
|
||||
async def __call__(self, data, final=False):
|
||||
"""Make the responder callable for compatibility with requestor"""
|
||||
await self.send(data)
|
||||
if final:
|
||||
self.completed = True
|
||||
|
||||
class _TokenShim:
|
||||
def __init__(self, token):
|
||||
self.headers = (
|
||||
{"Authorization": f"Bearer {token}"} if token else {}
|
||||
)
|
||||
|
||||
|
||||
class MessageDispatcher:
|
||||
|
||||
def __init__(self, max_workers: int = 10, config_receiver=None, backend=None):
|
||||
def __init__(self, max_workers=10, config_receiver=None, backend=None,
|
||||
auth=None, timeout=120):
|
||||
self.max_workers = max_workers
|
||||
self.semaphore = asyncio.Semaphore(max_workers)
|
||||
self.active_tasks = set()
|
||||
self.backend = backend
|
||||
self.auth = auth
|
||||
|
||||
# Use DispatcherManager for flow and service management
|
||||
if backend and config_receiver:
|
||||
self.dispatcher_manager = DispatcherManager(backend, config_receiver, prefix="rev-gateway")
|
||||
if backend and config_receiver and auth:
|
||||
self.dispatcher_manager = DispatcherManager(
|
||||
backend, config_receiver,
|
||||
auth=auth,
|
||||
prefix="rev-gateway",
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
self.dispatcher_manager = None
|
||||
logger.warning("No backend or config_receiver provided - using fallback mode")
|
||||
|
||||
# Service name mapping from websocket protocol to translator registry
|
||||
logger.warning(
|
||||
"Missing backend, config_receiver, or auth "
|
||||
"— using fallback mode"
|
||||
)
|
||||
|
||||
self.service_mapping = {
|
||||
"text-completion": "text-completion",
|
||||
"graph-rag": "graph-rag",
|
||||
"graph-rag": "graph-rag",
|
||||
"agent": "agent",
|
||||
"embeddings": "embeddings",
|
||||
"graph-embeddings": "graph-embeddings",
|
||||
"triples": "triples",
|
||||
"document-load": "document",
|
||||
"text-load": "text-document",
|
||||
"text-load": "text-document",
|
||||
"flow": "flow",
|
||||
"knowledge": "knowledge",
|
||||
"config": "config",
|
||||
"librarian": "librarian",
|
||||
"document-rag": "document-rag"
|
||||
"document-rag": "document-rag",
|
||||
}
|
||||
|
||||
async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
|
||||
|
||||
async def handle_message(
|
||||
self, message: Dict[Any, Any],
|
||||
sender: Callable[[dict], Awaitable[None]],
|
||||
):
|
||||
async with self.semaphore:
|
||||
task = asyncio.create_task(self._process_message(message))
|
||||
task = asyncio.create_task(
|
||||
self._process_message(message, sender)
|
||||
)
|
||||
self.active_tasks.add(task)
|
||||
|
||||
|
||||
try:
|
||||
result = await task
|
||||
return result
|
||||
await task
|
||||
finally:
|
||||
self.active_tasks.discard(task)
|
||||
|
||||
async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
|
||||
async def _authenticate(self, token):
|
||||
if not self.auth:
|
||||
raise RuntimeError("Auth not configured")
|
||||
return await self.auth.authenticate(_TokenShim(token))
|
||||
|
||||
async def _process_message(
|
||||
self, message: Dict[Any, Any],
|
||||
sender: Callable[[dict], Awaitable[None]],
|
||||
):
|
||||
request_id = message.get('id', str(uuid.uuid4()))
|
||||
service = message.get('service')
|
||||
request_data = message.get('request', {})
|
||||
flow_id = message.get('flow', 'default') # Default flow
|
||||
|
||||
logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
|
||||
|
||||
token = message.get('token', '')
|
||||
flow_id = message.get('flow', 'default')
|
||||
|
||||
logger.info(
|
||||
f"Processing message {request_id} for service "
|
||||
f"{service} on flow {flow_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
if not self.dispatcher_manager:
|
||||
raise RuntimeError("DispatcherManager not available - backend and config_receiver required")
|
||||
|
||||
# Use DispatcherManager for flow-based processing
|
||||
responder = WebSocketResponder()
|
||||
|
||||
# Map websocket service name to dispatcher service name
|
||||
raise RuntimeError(
|
||||
"DispatcherManager not available"
|
||||
)
|
||||
|
||||
identity = await self._authenticate(token)
|
||||
workspace = identity.workspace
|
||||
|
||||
async def responder(resp, fin):
|
||||
await sender({
|
||||
"id": request_id,
|
||||
"response": resp,
|
||||
"complete": fin,
|
||||
})
|
||||
|
||||
dispatcher_service = self.service_mapping.get(service, service)
|
||||
|
||||
# Check if this is a global service or flow service
|
||||
|
||||
from ..gateway.dispatch.manager import global_dispatchers
|
||||
if dispatcher_service in global_dispatchers:
|
||||
# Use global service dispatcher
|
||||
await self.dispatcher_manager.invoke_global_service(
|
||||
request_data, responder, dispatcher_service
|
||||
request_data, responder, dispatcher_service,
|
||||
workspace=workspace,
|
||||
)
|
||||
else:
|
||||
# Use DispatcherManager to process the request through Pulsar queues
|
||||
await self.dispatcher_manager.invoke_flow_service(
|
||||
request_data, responder, flow_id, dispatcher_service
|
||||
request_data, responder, workspace, flow_id,
|
||||
dispatcher_service,
|
||||
)
|
||||
|
||||
# Get the response from the responder
|
||||
if responder.completed:
|
||||
response_data = responder.response
|
||||
else:
|
||||
response_data = {'error': 'No response received'}
|
||||
|
||||
response = {
|
||||
'id': request_id,
|
||||
'response': response_data
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message {request_id}: {e}")
|
||||
response = {
|
||||
'id': request_id,
|
||||
'response': {'error': str(e)}
|
||||
}
|
||||
|
||||
await sender({
|
||||
"id": request_id,
|
||||
"error": {"message": str(e), "type": "error"},
|
||||
"complete": True,
|
||||
})
|
||||
|
||||
logger.info(f"Completed processing message {request_id}")
|
||||
return response
|
||||
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
if self.active_tasks:
|
||||
logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete")
|
||||
logger.info(
|
||||
f"Waiting for {len(self.active_tasks)} active "
|
||||
f"tasks to complete"
|
||||
)
|
||||
await asyncio.gather(*self.active_tasks, return_exceptions=True)
|
||||
|
||||
# DispatcherManager handles its own cleanup
|
||||
|
||||
logger.info("Dispatcher shutdown complete")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
"""
|
||||
Reverse gateway. Initiates outbound WebSocket connections to a remote
|
||||
relay and dispatches incoming requests through the same DispatcherManager
|
||||
pipeline as api-gateway.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
import logging
|
||||
|
|
@ -9,82 +15,86 @@ from typing import Optional
|
|||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from .dispatcher import MessageDispatcher
|
||||
from ..gateway.auth import IamAuth
|
||||
from ..gateway.config.receiver import ConfigReceiver
|
||||
from ..base import get_pubsub
|
||||
from ..base.pubsub import get_pubsub, add_pubsub_args
|
||||
from ..base.logging import setup_logging, add_logging_args
|
||||
|
||||
logger = logging.getLogger("rev_gateway")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
default_websocket = "ws://localhost:7650/out"
|
||||
default_timeout = 600
|
||||
|
||||
class ReverseGateway:
|
||||
|
||||
def __init__(self, websocket_uri: str = None, max_workers: int = 10,
|
||||
pulsar_host: str = None, pulsar_api_key: str = None,
|
||||
pulsar_listener: str = None):
|
||||
# Set default WebSocket URI with environment variable support
|
||||
|
||||
def __init__(self, **config):
|
||||
websocket_uri = config.get("websocket_uri")
|
||||
if websocket_uri is None:
|
||||
websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)
|
||||
|
||||
# Parse and validate the WebSocket URI
|
||||
|
||||
parsed_uri = urlparse(websocket_uri)
|
||||
if parsed_uri.scheme not in ('ws', 'wss'):
|
||||
raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}")
|
||||
raise ValueError(
|
||||
f"WebSocket URI must use ws:// or wss:// scheme, "
|
||||
f"got: {parsed_uri.scheme}"
|
||||
)
|
||||
if not parsed_uri.netloc:
|
||||
raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}")
|
||||
|
||||
# Store parsed components for debugging/logging
|
||||
raise ValueError(
|
||||
f"WebSocket URI must include hostname, "
|
||||
f"got: {websocket_uri}"
|
||||
)
|
||||
|
||||
self.websocket_uri = websocket_uri
|
||||
self.host = parsed_uri.hostname
|
||||
self.port = parsed_uri.port
|
||||
self.scheme = parsed_uri.scheme
|
||||
self.path = parsed_uri.path or "/ws"
|
||||
|
||||
# Construct the full URL (in case path was missing)
|
||||
|
||||
if not parsed_uri.path:
|
||||
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
|
||||
else:
|
||||
self.url = websocket_uri
|
||||
|
||||
self.max_workers = max_workers
|
||||
|
||||
self.max_workers = int(config.get("max_workers", 10))
|
||||
self.timeout = int(config.get("timeout", default_timeout))
|
||||
self.ws: Optional[ClientWebSocketResponse] = None
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.running = False
|
||||
self.reconnect_delay = 3.0
|
||||
|
||||
# Pulsar configuration
|
||||
self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
||||
self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
|
||||
self.pulsar_listener = pulsar_listener
|
||||
|
||||
# Create backend using factory
|
||||
backend_params = {
|
||||
'pulsar_host': self.pulsar_host,
|
||||
'pulsar_api_key': self.pulsar_api_key,
|
||||
'pulsar_listener': self.pulsar_listener,
|
||||
}
|
||||
self.backend = get_pubsub(**backend_params)
|
||||
self.backend = get_pubsub(**config)
|
||||
|
||||
# Initialize config receiver
|
||||
self.config_receiver = ConfigReceiver(self.backend)
|
||||
self.auth = IamAuth(
|
||||
backend=self.backend,
|
||||
id=config.get("id", "rev-gateway"),
|
||||
)
|
||||
|
||||
self.config_receiver = ConfigReceiver(
|
||||
self.backend, auth=self.auth,
|
||||
)
|
||||
|
||||
self.dispatcher = MessageDispatcher(
|
||||
self.max_workers, self.config_receiver, self.backend,
|
||||
auth=self.auth, timeout=self.timeout,
|
||||
)
|
||||
|
||||
# Initialize dispatcher with config_receiver and backend - must be created after config_receiver
|
||||
self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
try:
|
||||
if self.session is None:
|
||||
self.session = ClientSession()
|
||||
|
||||
|
||||
logger.info(f"Connecting to {self.url}")
|
||||
self.ws = await self.session.ws_connect(self.url)
|
||||
logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}")
|
||||
logger.info(
|
||||
f"WebSocket connection established to "
|
||||
f"{self.host}:{self.port or 'default'}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to {self.url}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def disconnect(self):
|
||||
if self.ws and not self.ws.closed:
|
||||
await self.ws.close()
|
||||
|
|
@ -92,32 +102,31 @@ class ReverseGateway:
|
|||
await self.session.close()
|
||||
self.ws = None
|
||||
self.session = None
|
||||
|
||||
|
||||
async def send_message(self, message: dict):
|
||||
if self.ws and not self.ws.closed:
|
||||
try:
|
||||
await self.ws.send_str(json.dumps(message))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message: {e}")
|
||||
|
||||
|
||||
async def handle_message(self, message: str):
|
||||
try:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
|
||||
msg_data = json.loads(message)
|
||||
response = await self.dispatcher.handle_message(msg_data)
|
||||
|
||||
if response:
|
||||
await self.send_message(response)
|
||||
|
||||
await self.dispatcher.handle_message(
|
||||
msg_data, self.send_message,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
|
||||
|
||||
async def listen(self):
|
||||
while self.running and self.ws and not self.ws.closed:
|
||||
try:
|
||||
msg = await self.ws.receive()
|
||||
|
||||
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
await self.handle_message(msg.data)
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
|
|
@ -125,31 +134,33 @@ class ReverseGateway:
|
|||
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
|
||||
logger.warning("WebSocket closed or error occurred")
|
||||
break
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in listen loop: {e}")
|
||||
break
|
||||
|
||||
|
||||
async def run(self):
|
||||
self.running = True
|
||||
logger.info("Starting reverse gateway")
|
||||
|
||||
# Start config receiver
|
||||
logger.info("Starting config receiver")
|
||||
|
||||
await self.auth.start()
|
||||
await self.config_receiver.start()
|
||||
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
if await self.connect():
|
||||
await self.listen()
|
||||
else:
|
||||
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")
|
||||
|
||||
logger.warning(
|
||||
f"Connection failed, retrying in "
|
||||
f"{self.reconnect_delay} seconds"
|
||||
)
|
||||
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
if self.running:
|
||||
await asyncio.sleep(self.reconnect_delay)
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutdown requested")
|
||||
break
|
||||
|
|
@ -157,77 +168,69 @@ class ReverseGateway:
|
|||
logger.error(f"Unexpected error: {e}")
|
||||
if self.running:
|
||||
await asyncio.sleep(self.reconnect_delay)
|
||||
|
||||
|
||||
await self.shutdown()
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Shutting down reverse gateway")
|
||||
self.running = False
|
||||
await self.dispatcher.shutdown()
|
||||
await self.disconnect()
|
||||
|
||||
# Close backend
|
||||
if hasattr(self, 'backend'):
|
||||
self.backend.close()
|
||||
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
def parse_args():
|
||||
|
||||
def run():
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="reverse-gateway",
|
||||
description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge"
|
||||
description=__doc__,
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--id',
|
||||
default='rev-gateway',
|
||||
help='Service identifier (default: rev-gateway)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--websocket-uri',
|
||||
default=None,
|
||||
help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)'
|
||||
help=f'WebSocket URI to connect to (default: {default_websocket})',
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--max-workers',
|
||||
type=int,
|
||||
default=10,
|
||||
help='Maximum concurrent message handlers (default: 10)'
|
||||
help='Maximum concurrent message handlers (default: 10)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-p', '--pulsar-host',
|
||||
default=None,
|
||||
help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--pulsar-api-key',
|
||||
default=None,
|
||||
help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--pulsar-listener',
|
||||
default=None,
|
||||
help='Pulsar listener name'
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def run():
|
||||
args = parse_args()
|
||||
|
||||
gateway = ReverseGateway(
|
||||
websocket_uri=args.websocket_uri,
|
||||
max_workers=args.max_workers,
|
||||
pulsar_host=args.pulsar_host,
|
||||
pulsar_api_key=args.pulsar_api_key,
|
||||
pulsar_listener=args.pulsar_listener
|
||||
parser.add_argument(
|
||||
'--timeout',
|
||||
type=int,
|
||||
default=default_timeout,
|
||||
help=f'Request timeout in seconds (default: {default_timeout})',
|
||||
)
|
||||
|
||||
|
||||
add_pubsub_args(parser)
|
||||
add_logging_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = vars(args)
|
||||
|
||||
setup_logging(args)
|
||||
|
||||
gateway = ReverseGateway(**args)
|
||||
|
||||
logger.info(f"Starting reverse gateway:")
|
||||
logger.info(f" WebSocket URI: {gateway.url}")
|
||||
logger.info(f" Max workers: {args.max_workers}")
|
||||
logger.info(f" Pulsar host: {gateway.pulsar_host}")
|
||||
|
||||
logger.info(f" Max workers: {gateway.max_workers}")
|
||||
|
||||
try:
|
||||
asyncio.run(gateway.run())
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue