trustgraph/dev-tools/tests/relay/websocket_relay.py
cybermaggedon e57f4669e1
Update rev-gateway for IAM integration (#940)
service.py:
- Constructor takes **config (same pattern as api-gateway) instead
  of individual args
- Creates IamAuth and calls await self.auth.start() before the
  message loop
- Passes auth to both ConfigReceiver and MessageDispatcher
- Uses add_pubsub_args / add_logging_args instead of hand-rolled
  Pulsar args
- Passes timeout through

dispatcher.py:
- Accepts auth and timeout parameters
- Passes both to DispatcherManager, supplying the previously missing
  auth argument whose absence would have crashed the service on startup

The remote end's requests now go through the same IAM authentication
path as api-gateway. Token validation, workspace resolution, and
permissions all work identically regardless of which direction
initiated the connection.

Fixed tests — the test now passes auth and timeout to MessageDispatcher
and verifies they're forwarded to DispatcherManager.

Update the rev-gateway dispatcher to align with IAM: a "token" parameter
must now be passed with each message.

Fix the websocket relay to align with the rev-gateway changes; it now
conforms to the api-gateway protocol.
2026-05-19 21:45:43 +01:00

280 lines
8.4 KiB
Python

#!/usr/bin/env python3
"""
WebSocket Relay Test Harness
This script creates a relay server with two WebSocket endpoints:
- /in/api/v1/socket - for test clients to connect to (speaks the api-gateway protocol)
- /out - for the reverse gateway to connect to (speaks the rev-gateway protocol)
Clients on /in/api/v1/socket authenticate with a first-frame auth message:
{"type": "auth", "token": "..."}
The relay stores the token and injects it, together with a "_relay_conn"
routing tag, into each subsequent message before forwarding to /out.
Responses from /out have the routing tag stripped and are forwarded back
to the originating /in connection; untagged messages are broadcast to
all /in connections.
Usage:
python websocket_relay.py [--port PORT] [--host HOST]
"""
import asyncio
import json
import logging
import argparse
from aiohttp import web, WSMsgType
from typing import Dict, Optional
# Module logger for the relay. Process-wide output defaults to INFO;
# main() lowers the root level to DEBUG when --verbose is given.
logger = logging.getLogger("websocket_relay")
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class InConnection:
    """Per-client state the relay keeps for one 'in' websocket."""

    def __init__(self, ws, conn_id):
        # The client's websocket and the relay-assigned connection id.
        self.ws, self.conn_id = ws, conn_id
        # Set by the first {"type": "auth"} frame; until then the client
        # has no token and is not authenticated.
        self.authenticated = False
        self.token: Optional[str] = None
class WebSocketRelay:
    """Relay frames between api-gateway clients ('in') and reverse gateways ('out').

    Text frames from an authenticated 'in' client are stamped with that
    client's token and a ``_relay_conn`` routing tag, then sent to every
    'out' connection.  Text frames from 'out' have the tag stripped and are
    delivered to the tagged 'in' client, or broadcast to every 'in' client
    when they carry no tag.
    """

    def __init__(self):
        # conn_id -> InConnection for each live 'in' client.
        self.in_connections: Dict[str, InConnection] = {}
        # Live 'out' websockets (reverse gateways).
        self.out_connections: set = set()
        # Counter backing _next_conn_id().
        self._conn_counter = 0

    def _next_conn_id(self):
        """Return a new id of the form ``conn-N``, unique within this relay."""
        self._conn_counter += 1
        return f"conn-{self._conn_counter}"

    async def handle_in_connection(self, request):
        """aiohttp handler for the client-facing websocket endpoint.

        Registers the client, feeds each text frame to _handle_in_message,
        and unregisters the client when the loop ends.  An error frame or
        any other non-text frame ends the session.

        Returns:
            The prepared ``web.WebSocketResponse``.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        conn_id = self._next_conn_id()
        conn = InConnection(ws, conn_id)
        self.in_connections[conn_id] = conn
        logger.info(
            f"New 'in' connection {conn_id}. "
            f"Total in: {len(self.in_connections)}, "
            f"out: {len(self.out_connections)}"
        )

        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    await self._handle_in_message(conn, msg.data)
                elif msg.type == WSMsgType.ERROR:
                    logger.error(
                        f"WebSocket error on 'in' connection "
                        f"{conn_id}: {ws.exception()}"
                    )
                    break
                else:
                    break
        except Exception as e:
            # Top-level boundary for this client: keep the traceback.
            logger.exception(
                f"Error in 'in' connection {conn_id}: {e}"
            )
        finally:
            # pop() rather than del: _broadcast_to_in may already have
            # removed this entry after a failed send, and a KeyError here
            # would escape the handler and skip returning ws.
            self.in_connections.pop(conn_id, None)
            logger.info(
                f"'in' connection {conn_id} closed. "
                f"Remaining in: {len(self.in_connections)}, "
                f"out: {len(self.out_connections)}"
            )

        return ws

    async def _handle_in_message(self, conn, data):
        """Process one text frame from an 'in' client.

        The frame is handled in one of these ways:

        * An auth frame stores the token and replies ``auth-ok``.
        * A frame from an unauthenticated client gets an ``auth-required``
          error.
        * A non-JSON or non-object frame is logged and dropped.
        * Any other frame is tagged with the token and connection id, then
          forwarded to 'out'.
        """
        try:
            message = json.loads(data)
        except json.JSONDecodeError:
            logger.warning(
                f"{conn.conn_id}: received non-JSON message"
            )
            return

        if isinstance(message, dict) and message.get("type") == "auth":
            conn.token = message.get("token", "")
            conn.authenticated = True
            logger.info(f"{conn.conn_id}: authenticated")
            await conn.ws.send_json({
                "type": "auth-ok",
                "workspace": "relayed",
            })
            return

        if not conn.authenticated:
            await conn.ws.send_json({
                "error": {
                    "message": "auth required",
                    "type": "auth-required",
                },
                "complete": True,
            })
            return

        if not isinstance(message, dict):
            # Only JSON objects can carry the injected fields; item
            # assignment on a list/str/number would raise and tear down
            # this client's socket.
            logger.warning(
                f"{conn.conn_id}: received non-object JSON message"
            )
            return

        message["token"] = conn.token
        message["_relay_conn"] = conn.conn_id
        forwarded = json.dumps(message)
        # Log a copy with the credential masked; the forwarded frame is intact.
        redacted = json.dumps({**message, "token": "***"})
        logger.info(f"IN {conn.conn_id} → OUT: {redacted}")
        await self._forward_to_out(forwarded)

    async def handle_out_connection(self, request):
        """aiohttp handler for the reverse-gateway websocket endpoint.

        Registers the socket, feeds each text frame to _handle_out_message,
        and unregisters the socket when the loop ends.

        Returns:
            The prepared ``web.WebSocketResponse``.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        self.out_connections.add(ws)
        logger.info(
            f"New 'out' connection. "
            f"Total in: {len(self.in_connections)}, "
            f"out: {len(self.out_connections)}"
        )

        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    await self._handle_out_message(msg.data)
                elif msg.type == WSMsgType.ERROR:
                    logger.error(
                        f"WebSocket error on 'out' connection: "
                        f"{ws.exception()}"
                    )
                    break
                else:
                    break
        except Exception as e:
            logger.exception(f"Error in 'out' connection: {e}")
        finally:
            self.out_connections.discard(ws)
            logger.info(
                f"'out' connection closed. "
                f"Remaining in: {len(self.in_connections)}, "
                f"out: {len(self.out_connections)}"
            )

        return ws

    async def _handle_out_message(self, data):
        """Route one text frame from 'out' back to the 'in' side.

        The ``_relay_conn`` tag is removed before delivery.  The frame is
        routed as follows:

        * Tagged frames go only to that client.
        * Frames tagged for a client that has already gone are dropped.
        * Untagged frames are broadcast to every client.
        * Non-JSON and non-object frames are logged and dropped.
        """
        try:
            message = json.loads(data)
        except json.JSONDecodeError:
            logger.warning("OUT: received non-JSON message")
            return

        if not isinstance(message, dict):
            # .pop() on a non-dict would raise and end the whole 'out' link.
            logger.warning("OUT: received non-object JSON message")
            return

        conn_id = message.pop("_relay_conn", None)
        forwarded = json.dumps(message)

        if not conn_id:
            logger.info(f"OUT → IN broadcast: {forwarded}")
            await self._broadcast_to_in(forwarded)
            return

        conn = self.in_connections.get(conn_id)
        if conn is None:
            # The originating client disconnected.  Broadcasting its
            # response would leak it to unrelated clients, so drop it.
            logger.warning(
                f"OUT → IN {conn_id}: connection gone, dropping {forwarded}"
            )
            return

        logger.info(f"OUT → IN {conn_id}: {forwarded}")
        try:
            if not conn.ws.closed:
                await conn.ws.send_str(forwarded)
        except Exception as e:
            logger.error(
                f"Error forwarding to 'in' {conn_id}: {e}"
            )

    async def _broadcast_to_in(self, data):
        """Send *data* to every 'in' client.

        Clients that are closed, or whose send raises, are removed from
        in_connections.
        """
        closed = []
        # Iterate over a snapshot: awaiting send_str lets other handlers
        # mutate in_connections.
        for conn_id, conn in list(self.in_connections.items()):
            try:
                if conn.ws.closed:
                    closed.append(conn_id)
                    continue
                await conn.ws.send_str(data)
            except Exception as e:
                logger.error(
                    f"Error broadcasting to 'in' {conn_id}: {e}"
                )
                closed.append(conn_id)
        for conn_id in closed:
            self.in_connections.pop(conn_id, None)

    async def _forward_to_out(self, data):
        """Send *data* to every 'out' connection.

        Closed or failing sockets are discarded.  A warning is logged when
        no 'out' connection exists, because the request is then lost.
        """
        if not self.out_connections:
            logger.warning("No 'out' connection available; dropping message")
            return
        closed = []
        for ws in list(self.out_connections):
            try:
                if ws.closed:
                    closed.append(ws)
                    continue
                await ws.send_str(data)
            except Exception as e:
                logger.error(f"Error forwarding to 'out': {e}")
                closed.append(ws)
        for ws in closed:
            self.out_connections.discard(ws)
async def create_app(relay):
    """Build the aiohttp application that serves *relay*.

    Routes (all GET, registered in this order):
      /in/api/v1/socket  client-facing websocket (api-gateway protocol)
      /out               reverse-gateway websocket
      /status and /      JSON summary of current connection counts

    Declared ``async`` because main() passes the coroutine straight to
    ``web.run_app``.
    """
    async def status(request):
        # Snapshot of the relay's bookkeeping, in a fixed key order.
        summary = {
            'in_connections': len(relay.in_connections),
            'out_connections': len(relay.out_connections),
            'status': 'running',
        }
        return web.json_response(summary)

    app = web.Application()
    route_table = (
        ('/in/api/v1/socket', relay.handle_in_connection),
        ('/out', relay.handle_out_connection),
        ('/status', status),
        ('/', status),
    )
    for path, handler in route_table:
        app.router.add_get(path, handler)
    return app
def main():
    """Parse command-line options, print a usage banner, and run the relay."""
    parser = argparse.ArgumentParser(
        description="WebSocket Relay Test Harness"
    )
    parser.add_argument(
        '--host', default='localhost',
        help='Host to bind to (default: localhost)',
    )
    parser.add_argument(
        '--port', type=int, default=8080,
        help='Port to bind to (default: 8080)',
    )
    parser.add_argument(
        '--verbose', '-v', action='store_true',
        help='Enable verbose logging',
    )
    args = parser.parse_args()

    if args.verbose:
        # Lower the root logger so DEBUG records from every logger appear.
        logging.getLogger().setLevel(logging.DEBUG)

    banner = [
        f"Starting WebSocket Relay on {args.host}:{args.port}",
        f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket",
        f" 'out' endpoint: ws://{args.host}:{args.port}/out",
        f" Status: http://{args.host}:{args.port}/status",
        "",
        "Client protocol (same as api-gateway):",
        ' 1. Connect to /in/api/v1/socket',
        ' 2. Send: {"type": "auth", "token": "tg_..."}',
        ' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}',
        ' 4. Send requests as normal',
    ]
    for line in banner:
        print(line)

    relay = WebSocketRelay()
    web.run_app(create_app(relay), host=args.host, port=args.port)


# Script entry point.
if __name__ == "__main__":
    main()