mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-27 16:25:12 +02:00
Update rev-gateway for IAM integration (#940)
service.py: - Constructor takes **config (same pattern as api-gateway) instead of individual args - Creates IamAuth and calls await self.auth.start() before the message loop - Passes auth to both ConfigReceiver and MessageDispatcher - Uses add_pubsub_args / add_logging_args instead of hand-rolled Pulsar args - Passes timeout through dispatcher.py: - Accepts auth and timeout parameters - Passes both to DispatcherManager — fixes the missing auth argument that would have crashed on startup The remote end's requests now go through the same IAM authentication path as api-gateway. Token validation, workspace resolution, and permissions all work identically regardless of which direction initiated the connection. Fixed tests — the test now passes auth and timeout to MessageDispatcher and verifies they're forwarded to DispatcherManager. Update rev gateway dispatcher to align with IAM. A "token" parameter must be passed with each message. Fix websocket relay to align with rev-gateway changes, conforms to the api-gateway protocol.
This commit is contained in:
parent
4e3bd85abc
commit
e57f4669e1
7 changed files with 914 additions and 865 deletions
|
|
@ -3,208 +3,278 @@
|
|||
WebSocket Relay Test Harness
|
||||
|
||||
This script creates a relay server with two WebSocket endpoints:
|
||||
- /in - for test clients to connect to
|
||||
- /out - for reverse gateway to connect to
|
||||
- /in - for test clients to connect to (speaks api-gateway protocol)
|
||||
- /out - for reverse gateway to connect to (speaks rev-gateway protocol)
|
||||
|
||||
Messages are bidirectionally relayed between the two connections.
|
||||
Clients on /in authenticate with a first-frame auth message:
|
||||
{"type": "auth", "token": "..."}
|
||||
|
||||
The relay stores the token and injects it into each subsequent message
|
||||
before forwarding to /out. Responses from /out are forwarded back to
|
||||
the originating /in connection unchanged.
|
||||
|
||||
Usage:
|
||||
python websocket_relay.py [--port PORT] [--host HOST]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
from aiohttp import web, WSMsgType
|
||||
import weakref
|
||||
from typing import Optional, Set
|
||||
from typing import Dict, Optional
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger("websocket_relay")
|
||||
|
||||
|
||||
class InConnection:
|
||||
def __init__(self, ws, conn_id):
|
||||
self.ws = ws
|
||||
self.conn_id = conn_id
|
||||
self.token: Optional[str] = None
|
||||
self.authenticated = False
|
||||
|
||||
|
||||
class WebSocketRelay:
|
||||
"""WebSocket relay that forwards messages between 'in' and 'out' connections"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.in_connections: Set = weakref.WeakSet()
|
||||
self.out_connections: Set = weakref.WeakSet()
|
||||
|
||||
self.in_connections: Dict[str, InConnection] = {}
|
||||
self.out_connections: set = set()
|
||||
self._conn_counter = 0
|
||||
|
||||
def _next_conn_id(self):
|
||||
self._conn_counter += 1
|
||||
return f"conn-{self._conn_counter}"
|
||||
|
||||
async def handle_in_connection(self, request):
|
||||
"""Handle incoming connections on /in endpoint"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.in_connections.add(ws)
|
||||
logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
|
||||
conn_id = self._next_conn_id()
|
||||
conn = InConnection(ws, conn_id)
|
||||
self.in_connections[conn_id] = conn
|
||||
logger.info(
|
||||
f"New 'in' connection {conn_id}. "
|
||||
f"Total in: {len(self.in_connections)}, "
|
||||
f"out: {len(self.out_connections)}"
|
||||
)
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
data = msg.data
|
||||
logger.info(f"IN → OUT: {data}")
|
||||
await self._forward_to_out(data)
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
data = msg.data
|
||||
logger.info(f"IN → OUT: {len(data)} bytes (binary)")
|
||||
await self._forward_to_out(data, binary=True)
|
||||
await self._handle_in_message(conn, msg.data)
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket error on 'in' connection: {ws.exception()}")
|
||||
logger.error(
|
||||
f"WebSocket error on 'in' connection "
|
||||
f"{conn_id}: {ws.exception()}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in 'in' connection handler: {e}")
|
||||
logger.error(
|
||||
f"Error in 'in' connection {conn_id}: {e}"
|
||||
)
|
||||
finally:
|
||||
logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
del self.in_connections[conn_id]
|
||||
logger.info(
|
||||
f"'in' connection {conn_id} closed. "
|
||||
f"Remaining in: {len(self.in_connections)}, "
|
||||
f"out: {len(self.out_connections)}"
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
async def handle_out_connection(self, request):
|
||||
"""Handle outgoing connections on /out endpoint"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.out_connections.add(ws)
|
||||
logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
|
||||
async def _handle_in_message(self, conn, data):
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
data = msg.data
|
||||
logger.info(f"OUT → IN: {data}")
|
||||
await self._forward_to_in(data)
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
data = msg.data
|
||||
logger.info(f"OUT → IN: {len(data)} bytes (binary)")
|
||||
await self._forward_to_in(data, binary=True)
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket error on 'out' connection: {ws.exception()}")
|
||||
break
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in 'out' connection handler: {e}")
|
||||
finally:
|
||||
logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
return ws
|
||||
|
||||
async def _forward_to_out(self, data, binary=False):
|
||||
"""Forward message from 'in' to all 'out' connections"""
|
||||
if not self.out_connections:
|
||||
logger.warning("No 'out' connections available to forward message")
|
||||
message = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"{conn.conn_id}: received non-JSON message"
|
||||
)
|
||||
return
|
||||
|
||||
closed_connections = []
|
||||
|
||||
if isinstance(message, dict) and message.get("type") == "auth":
|
||||
conn.token = message.get("token", "")
|
||||
conn.authenticated = True
|
||||
logger.info(f"{conn.conn_id}: authenticated")
|
||||
await conn.ws.send_json({
|
||||
"type": "auth-ok",
|
||||
"workspace": "relayed",
|
||||
})
|
||||
return
|
||||
|
||||
if not conn.authenticated:
|
||||
await conn.ws.send_json({
|
||||
"error": {
|
||||
"message": "auth required",
|
||||
"type": "auth-required",
|
||||
},
|
||||
"complete": True,
|
||||
})
|
||||
return
|
||||
|
||||
message["token"] = conn.token
|
||||
message["_relay_conn"] = conn.conn_id
|
||||
|
||||
forwarded = json.dumps(message)
|
||||
logger.info(f"IN {conn.conn_id} → OUT: {forwarded}")
|
||||
await self._forward_to_out(forwarded)
|
||||
|
||||
async def handle_out_connection(self, request):
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.out_connections.add(ws)
|
||||
logger.info(
|
||||
f"New 'out' connection. "
|
||||
f"Total in: {len(self.in_connections)}, "
|
||||
f"out: {len(self.out_connections)}"
|
||||
)
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
await self._handle_out_message(msg.data)
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
logger.error(
|
||||
f"WebSocket error on 'out' connection: "
|
||||
f"{ws.exception()}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in 'out' connection: {e}")
|
||||
finally:
|
||||
self.out_connections.discard(ws)
|
||||
logger.info(
|
||||
f"'out' connection closed. "
|
||||
f"Remaining in: {len(self.in_connections)}, "
|
||||
f"out: {len(self.out_connections)}"
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
async def _handle_out_message(self, data):
|
||||
try:
|
||||
message = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("OUT: received non-JSON message")
|
||||
return
|
||||
|
||||
conn_id = message.pop("_relay_conn", None)
|
||||
|
||||
forwarded = json.dumps(message)
|
||||
logger.info(f"OUT → IN {conn_id or 'broadcast'}: {forwarded}")
|
||||
|
||||
if conn_id and conn_id in self.in_connections:
|
||||
conn = self.in_connections[conn_id]
|
||||
try:
|
||||
if not conn.ws.closed:
|
||||
await conn.ws.send_str(forwarded)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error forwarding to 'in' {conn_id}: {e}"
|
||||
)
|
||||
else:
|
||||
await self._broadcast_to_in(forwarded)
|
||||
|
||||
async def _broadcast_to_in(self, data):
|
||||
closed = []
|
||||
for conn_id, conn in list(self.in_connections.items()):
|
||||
try:
|
||||
if conn.ws.closed:
|
||||
closed.append(conn_id)
|
||||
continue
|
||||
await conn.ws.send_str(data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error broadcasting to 'in' {conn_id}: {e}"
|
||||
)
|
||||
closed.append(conn_id)
|
||||
for conn_id in closed:
|
||||
self.in_connections.pop(conn_id, None)
|
||||
|
||||
async def _forward_to_out(self, data):
|
||||
closed = []
|
||||
for ws in list(self.out_connections):
|
||||
try:
|
||||
if ws.closed:
|
||||
closed_connections.append(ws)
|
||||
closed.append(ws)
|
||||
continue
|
||||
|
||||
if binary:
|
||||
await ws.send_bytes(data)
|
||||
else:
|
||||
await ws.send_str(data)
|
||||
await ws.send_str(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding to 'out' connection: {e}")
|
||||
closed_connections.append(ws)
|
||||
|
||||
# Clean up closed connections
|
||||
for ws in closed_connections:
|
||||
if ws in self.out_connections:
|
||||
self.out_connections.discard(ws)
|
||||
|
||||
async def _forward_to_in(self, data, binary=False):
|
||||
"""Forward message from 'out' to all 'in' connections"""
|
||||
if not self.in_connections:
|
||||
logger.warning("No 'in' connections available to forward message")
|
||||
return
|
||||
|
||||
closed_connections = []
|
||||
for ws in list(self.in_connections):
|
||||
try:
|
||||
if ws.closed:
|
||||
closed_connections.append(ws)
|
||||
continue
|
||||
|
||||
if binary:
|
||||
await ws.send_bytes(data)
|
||||
else:
|
||||
await ws.send_str(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding to 'in' connection: {e}")
|
||||
closed_connections.append(ws)
|
||||
|
||||
# Clean up closed connections
|
||||
for ws in closed_connections:
|
||||
if ws in self.in_connections:
|
||||
self.in_connections.discard(ws)
|
||||
logger.error(f"Error forwarding to 'out': {e}")
|
||||
closed.append(ws)
|
||||
for ws in closed:
|
||||
self.out_connections.discard(ws)
|
||||
|
||||
|
||||
async def create_app(relay):
|
||||
"""Create the web application with routes"""
|
||||
app = web.Application()
|
||||
|
||||
# Add routes
|
||||
app.router.add_get('/in', relay.handle_in_connection)
|
||||
|
||||
app.router.add_get('/in/api/v1/socket', relay.handle_in_connection)
|
||||
app.router.add_get('/out', relay.handle_out_connection)
|
||||
|
||||
# Add a simple status endpoint
|
||||
|
||||
async def status(request):
|
||||
status_info = {
|
||||
return web.json_response({
|
||||
'in_connections': len(relay.in_connections),
|
||||
'out_connections': len(relay.out_connections),
|
||||
'status': 'running'
|
||||
}
|
||||
return web.json_response(status_info)
|
||||
|
||||
'status': 'running',
|
||||
})
|
||||
|
||||
app.router.add_get('/status', status)
|
||||
app.router.add_get('/', status) # Root also shows status
|
||||
|
||||
app.router.add_get('/', status)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="WebSocket Relay Test Harness"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--host',
|
||||
'--host',
|
||||
default='localhost',
|
||||
help='Host to bind to (default: localhost)'
|
||||
help='Host to bind to (default: localhost)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--port',
|
||||
type=int,
|
||||
'--port',
|
||||
type=int,
|
||||
default=8080,
|
||||
help='Port to bind to (default: 8080)'
|
||||
help='Port to bind to (default: 8080)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verbose', '-v',
|
||||
action='store_true',
|
||||
help='Enable verbose logging'
|
||||
help='Enable verbose logging',
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if args.verbose:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
relay = WebSocketRelay()
|
||||
|
||||
|
||||
print(f"Starting WebSocket Relay on {args.host}:{args.port}")
|
||||
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in")
|
||||
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket")
|
||||
print(f" 'out' endpoint: ws://{args.host}:{args.port}/out")
|
||||
print(f" Status: http://{args.host}:{args.port}/status")
|
||||
print()
|
||||
print("Usage:")
|
||||
print(f" Test client connects to: ws://{args.host}:{args.port}/in")
|
||||
print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out")
|
||||
|
||||
print("Client protocol (same as api-gateway):")
|
||||
print(' 1. Connect to /in/api/v1/socket')
|
||||
print(' 2. Send: {"type": "auth", "token": "tg_..."}')
|
||||
print(' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}')
|
||||
print(' 4. Send requests as normal')
|
||||
|
||||
web.run_app(create_app(relay), host=args.host, port=args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue