Update rev-gateway for IAM integration (#940)

service.py:
- Constructor takes **config (same pattern as api-gateway) instead
  of individual args
- Creates IamAuth and calls await self.auth.start() before the
  message loop
- Passes auth to both ConfigReceiver and MessageDispatcher
- Uses add_pubsub_args / add_logging_args instead of hand-rolled
  Pulsar args
- Passes timeout through

dispatcher.py:
- Accepts auth and timeout parameters
- Passes both to DispatcherManager — fixes the missing auth argument
  that would have crashed on startup

The remote end's requests now go through the same IAM authentication
path as api-gateway. Token validation, workspace resolution, and
permissions all work identically regardless of which direction
initiated the connection.

Fix tests: the test now passes auth and timeout to MessageDispatcher
and verifies that both are forwarded to DispatcherManager.

Update the rev-gateway dispatcher to align with IAM: a "token" field
must now be passed with each message.

Fix the websocket relay to align with the rev-gateway changes; it now
conforms to the api-gateway protocol.
This commit is contained in:
cybermaggedon 2026-05-19 21:45:43 +01:00 committed by GitHub
parent 4e3bd85abc
commit e57f4669e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 914 additions and 865 deletions

View file

@ -1,130 +1,140 @@
import asyncio
import logging
import uuid
from typing import Dict, Any, Optional
from trustgraph.messaging import TranslatorRegistry
from typing import Dict, Any, Optional, Callable, Awaitable
from ..gateway.dispatch.manager import DispatcherManager
logger = logging.getLogger("dispatcher")
logger.setLevel(logging.INFO)
class WebSocketResponder:
    """Single-slot sink that remembers the most recent response.

    Each delivery overwrites the previously stored value, and the
    ``completed`` flag is raised as soon as anything has been delivered.
    """

    def __init__(self):
        # Nothing received yet.
        self.response, self.completed = None, False

    async def send(self, data):
        """Store *data* as the current response and flag completion."""
        self.completed = True
        self.response = data

    async def __call__(self, data, final=False):
        """Allow the instance to be awaited like a callback.

        Delegates to ``send``; *final* additionally forces the
        ``completed`` flag on (``send`` has already set it).
        """
        await self.send(data)
        if not final:
            return
        self.completed = True
class _TokenShim:
    """Bare object exposing only a ``headers`` mapping built from a token.

    A truthy *token* yields a single ``Authorization: Bearer <token>``
    entry; a falsy one (``None``, ``""``) yields an empty mapping.
    """

    def __init__(self, token):
        if token:
            self.headers = {"Authorization": f"Bearer {token}"}
        else:
            self.headers = {}
class MessageDispatcher:
def __init__(self, max_workers: int = 10, config_receiver=None, backend=None):
def __init__(self, max_workers=10, config_receiver=None, backend=None,
auth=None, timeout=120):
self.max_workers = max_workers
self.semaphore = asyncio.Semaphore(max_workers)
self.active_tasks = set()
self.backend = backend
self.auth = auth
# Use DispatcherManager for flow and service management
if backend and config_receiver:
self.dispatcher_manager = DispatcherManager(backend, config_receiver, prefix="rev-gateway")
if backend and config_receiver and auth:
self.dispatcher_manager = DispatcherManager(
backend, config_receiver,
auth=auth,
prefix="rev-gateway",
timeout=timeout,
)
else:
self.dispatcher_manager = None
logger.warning("No backend or config_receiver provided - using fallback mode")
# Service name mapping from websocket protocol to translator registry
logger.warning(
"Missing backend, config_receiver, or auth "
"— using fallback mode"
)
self.service_mapping = {
"text-completion": "text-completion",
"graph-rag": "graph-rag",
"graph-rag": "graph-rag",
"agent": "agent",
"embeddings": "embeddings",
"graph-embeddings": "graph-embeddings",
"triples": "triples",
"document-load": "document",
"text-load": "text-document",
"text-load": "text-document",
"flow": "flow",
"knowledge": "knowledge",
"config": "config",
"librarian": "librarian",
"document-rag": "document-rag"
"document-rag": "document-rag",
}
async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
async def handle_message(
self, message: Dict[Any, Any],
sender: Callable[[dict], Awaitable[None]],
):
async with self.semaphore:
task = asyncio.create_task(self._process_message(message))
task = asyncio.create_task(
self._process_message(message, sender)
)
self.active_tasks.add(task)
try:
result = await task
return result
await task
finally:
self.active_tasks.discard(task)
async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]:
async def _authenticate(self, token):
    """Authenticate *token* through the configured auth object.

    The token is wrapped in a ``_TokenShim`` so that it is presented as
    a bearer ``Authorization`` header, and the result of
    ``self.auth.authenticate`` is returned unchanged.

    Raises:
        RuntimeError: if ``self.auth`` is unset or falsy.
    """
    if self.auth:
        shim = _TokenShim(token)
        return await self.auth.authenticate(shim)
    raise RuntimeError("Auth not configured")
async def _process_message(
self, message: Dict[Any, Any],
sender: Callable[[dict], Awaitable[None]],
):
request_id = message.get('id', str(uuid.uuid4()))
service = message.get('service')
request_data = message.get('request', {})
flow_id = message.get('flow', 'default') # Default flow
logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
token = message.get('token', '')
flow_id = message.get('flow', 'default')
logger.info(
f"Processing message {request_id} for service "
f"{service} on flow {flow_id}"
)
try:
if not self.dispatcher_manager:
raise RuntimeError("DispatcherManager not available - backend and config_receiver required")
# Use DispatcherManager for flow-based processing
responder = WebSocketResponder()
# Map websocket service name to dispatcher service name
raise RuntimeError(
"DispatcherManager not available"
)
identity = await self._authenticate(token)
workspace = identity.workspace
async def responder(resp, fin):
await sender({
"id": request_id,
"response": resp,
"complete": fin,
})
dispatcher_service = self.service_mapping.get(service, service)
# Check if this is a global service or flow service
from ..gateway.dispatch.manager import global_dispatchers
if dispatcher_service in global_dispatchers:
# Use global service dispatcher
await self.dispatcher_manager.invoke_global_service(
request_data, responder, dispatcher_service
request_data, responder, dispatcher_service,
workspace=workspace,
)
else:
# Use DispatcherManager to process the request through Pulsar queues
await self.dispatcher_manager.invoke_flow_service(
request_data, responder, flow_id, dispatcher_service
request_data, responder, workspace, flow_id,
dispatcher_service,
)
# Get the response from the responder
if responder.completed:
response_data = responder.response
else:
response_data = {'error': 'No response received'}
response = {
'id': request_id,
'response': response_data
}
except Exception as e:
logger.error(f"Error processing message {request_id}: {e}")
response = {
'id': request_id,
'response': {'error': str(e)}
}
await sender({
"id": request_id,
"error": {"message": str(e), "type": "error"},
"complete": True,
})
logger.info(f"Completed processing message {request_id}")
return response
async def shutdown(self):
if self.active_tasks:
logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete")
logger.info(
f"Waiting for {len(self.active_tasks)} active "
f"tasks to complete"
)
await asyncio.gather(*self.active_tasks, return_exceptions=True)
# DispatcherManager handles its own cleanup
logger.info("Dispatcher shutdown complete")

View file

@ -1,3 +1,9 @@
"""
Reverse gateway. Initiates outbound WebSocket connections to a remote
relay and dispatches incoming requests through the same DispatcherManager
pipeline as api-gateway.
"""
import asyncio
import argparse
import logging
@ -9,82 +15,86 @@ from typing import Optional
from urllib.parse import urlparse, urlunparse
from .dispatcher import MessageDispatcher
from ..gateway.auth import IamAuth
from ..gateway.config.receiver import ConfigReceiver
from ..base import get_pubsub
from ..base.pubsub import get_pubsub, add_pubsub_args
from ..base.logging import setup_logging, add_logging_args
logger = logging.getLogger("rev_gateway")
logger.setLevel(logging.INFO)
default_websocket = "ws://localhost:7650/out"
default_timeout = 600
class ReverseGateway:
def __init__(self, websocket_uri: str = None, max_workers: int = 10,
pulsar_host: str = None, pulsar_api_key: str = None,
pulsar_listener: str = None):
# Set default WebSocket URI with environment variable support
def __init__(self, **config):
websocket_uri = config.get("websocket_uri")
if websocket_uri is None:
websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)
# Parse and validate the WebSocket URI
parsed_uri = urlparse(websocket_uri)
if parsed_uri.scheme not in ('ws', 'wss'):
raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}")
raise ValueError(
f"WebSocket URI must use ws:// or wss:// scheme, "
f"got: {parsed_uri.scheme}"
)
if not parsed_uri.netloc:
raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}")
# Store parsed components for debugging/logging
raise ValueError(
f"WebSocket URI must include hostname, "
f"got: {websocket_uri}"
)
self.websocket_uri = websocket_uri
self.host = parsed_uri.hostname
self.port = parsed_uri.port
self.scheme = parsed_uri.scheme
self.path = parsed_uri.path or "/ws"
# Construct the full URL (in case path was missing)
if not parsed_uri.path:
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
else:
self.url = websocket_uri
self.max_workers = max_workers
self.max_workers = int(config.get("max_workers", 10))
self.timeout = int(config.get("timeout", default_timeout))
self.ws: Optional[ClientWebSocketResponse] = None
self.session: Optional[ClientSession] = None
self.running = False
self.reconnect_delay = 3.0
# Pulsar configuration
self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
self.pulsar_listener = pulsar_listener
# Create backend using factory
backend_params = {
'pulsar_host': self.pulsar_host,
'pulsar_api_key': self.pulsar_api_key,
'pulsar_listener': self.pulsar_listener,
}
self.backend = get_pubsub(**backend_params)
self.backend = get_pubsub(**config)
# Initialize config receiver
self.config_receiver = ConfigReceiver(self.backend)
self.auth = IamAuth(
backend=self.backend,
id=config.get("id", "rev-gateway"),
)
self.config_receiver = ConfigReceiver(
self.backend, auth=self.auth,
)
self.dispatcher = MessageDispatcher(
self.max_workers, self.config_receiver, self.backend,
auth=self.auth, timeout=self.timeout,
)
# Initialize dispatcher with config_receiver and backend - must be created after config_receiver
self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend)
async def connect(self) -> bool:
try:
if self.session is None:
self.session = ClientSession()
logger.info(f"Connecting to {self.url}")
self.ws = await self.session.ws_connect(self.url)
logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}")
logger.info(
f"WebSocket connection established to "
f"{self.host}:{self.port or 'default'}"
)
return True
except Exception as e:
logger.error(f"Failed to connect to {self.url}: {e}")
return False
async def disconnect(self):
if self.ws and not self.ws.closed:
await self.ws.close()
@ -92,32 +102,31 @@ class ReverseGateway:
await self.session.close()
self.ws = None
self.session = None
async def send_message(self, message: dict):
    """Serialise *message* to JSON and write it to the websocket.

    Nothing is sent when there is no socket or the socket is already
    closed. A failure while serialising or sending is logged and not
    re-raised.
    """
    sock = self.ws
    if not sock or sock.closed:
        return
    try:
        await sock.send_str(json.dumps(message))
    except Exception as e:
        logger.error(f"Failed to send message: {e}")
async def handle_message(self, message: str):
try:
logger.debug(f"Received message: {message}")
msg_data = json.loads(message)
response = await self.dispatcher.handle_message(msg_data)
if response:
await self.send_message(response)
await self.dispatcher.handle_message(
msg_data, self.send_message,
)
except Exception as e:
logger.error(f"Error handling message: {e}")
async def listen(self):
while self.running and self.ws and not self.ws.closed:
try:
msg = await self.ws.receive()
if msg.type == WSMsgType.TEXT:
await self.handle_message(msg.data)
elif msg.type == WSMsgType.BINARY:
@ -125,31 +134,33 @@ class ReverseGateway:
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
logger.warning("WebSocket closed or error occurred")
break
except Exception as e:
logger.error(f"Error in listen loop: {e}")
break
async def run(self):
self.running = True
logger.info("Starting reverse gateway")
# Start config receiver
logger.info("Starting config receiver")
await self.auth.start()
await self.config_receiver.start()
while self.running:
try:
if await self.connect():
await self.listen()
else:
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")
logger.warning(
f"Connection failed, retrying in "
f"{self.reconnect_delay} seconds"
)
await self.disconnect()
if self.running:
await asyncio.sleep(self.reconnect_delay)
except KeyboardInterrupt:
logger.info("Shutdown requested")
break
@ -157,77 +168,69 @@ class ReverseGateway:
logger.error(f"Unexpected error: {e}")
if self.running:
await asyncio.sleep(self.reconnect_delay)
await self.shutdown()
async def shutdown(self):
    """Stop the gateway and release what it holds.

    Clears the ``running`` flag, waits for the dispatcher to shut down,
    drops the websocket connection, and finally closes the backend if
    the instance has one.
    """
    logger.info("Shutting down reverse gateway")
    self.running = False

    await self.dispatcher.shutdown()
    await self.disconnect()

    # Only the attribute's presence is checked, not its value.
    missing = object()
    backend = getattr(self, 'backend', missing)
    if backend is not missing:
        backend.close()
def stop(self):
    """Request termination by clearing the ``running`` flag.

    Only the flag is changed here; no connection or backend is closed.
    """
    self.running = False
def parse_args():
def run():
parser = argparse.ArgumentParser(
prog="reverse-gateway",
description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge"
description=__doc__,
)
parser.add_argument(
'--id',
default='rev-gateway',
help='Service identifier (default: rev-gateway)',
)
parser.add_argument(
'--websocket-uri',
default=None,
help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)'
help=f'WebSocket URI to connect to (default: {default_websocket})',
)
parser.add_argument(
'--max-workers',
type=int,
default=10,
help='Maximum concurrent message handlers (default: 10)'
help='Maximum concurrent message handlers (default: 10)',
)
parser.add_argument(
'-p', '--pulsar-host',
default=None,
help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
)
parser.add_argument(
'--pulsar-api-key',
default=None,
help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
)
parser.add_argument(
'--pulsar-listener',
default=None,
help='Pulsar listener name'
)
return parser.parse_args()
def run():
args = parse_args()
gateway = ReverseGateway(
websocket_uri=args.websocket_uri,
max_workers=args.max_workers,
pulsar_host=args.pulsar_host,
pulsar_api_key=args.pulsar_api_key,
pulsar_listener=args.pulsar_listener
parser.add_argument(
'--timeout',
type=int,
default=default_timeout,
help=f'Request timeout in seconds (default: {default_timeout})',
)
add_pubsub_args(parser)
add_logging_args(parser)
args = parser.parse_args()
args = vars(args)
setup_logging(args)
gateway = ReverseGateway(**args)
logger.info(f"Starting reverse gateway:")
logger.info(f" WebSocket URI: {gateway.url}")
logger.info(f" Max workers: {args.max_workers}")
logger.info(f" Pulsar host: {gateway.pulsar_host}")
logger.info(f" Max workers: {gateway.max_workers}")
try:
asyncio.run(gateway.run())
except KeyboardInterrupt: