mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-27 16:25:12 +02:00
Update rev-gateway for IAM integration (#940)
service.py: - Constructor takes **config (same pattern as api-gateway) instead of individual args - Creates IamAuth and calls await self.auth.start() before the message loop - Passes auth to both ConfigReceiver and MessageDispatcher - Uses add_pubsub_args / add_logging_args instead of hand-rolled Pulsar args - Passes timeout through dispatcher.py: - Accepts auth and timeout parameters - Passes both to DispatcherManager — fixes the missing auth argument that would have crashed on startup The remote end's requests now go through the same IAM authentication path as api-gateway. Token validation, workspace resolution, and permissions all work identically regardless of which direction initiated the connection. Fixed tests — the test now passes auth and timeout to MessageDispatcher and verifies they're forwarded to DispatcherManager. Update rev gateway dispatcher to align with IAM. A "token" parameter must be passed with each message. Fix websocket relay to align with rev-gateway changes, conforms to the api-gateway protocol.
This commit is contained in:
parent
4e3bd85abc
commit
e57f4669e1
7 changed files with 914 additions and 865 deletions
|
|
@ -1,130 +1,140 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
from trustgraph.messaging import TranslatorRegistry
|
||||
from typing import Dict, Any, Optional, Callable, Awaitable
|
||||
from ..gateway.dispatch.manager import DispatcherManager
|
||||
|
||||
logger = logging.getLogger("dispatcher")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class WebSocketResponder:
|
||||
"""Simple responder that captures response for websocket return"""
|
||||
def __init__(self):
|
||||
self.response = None
|
||||
self.completed = False
|
||||
|
||||
async def send(self, data):
|
||||
"""Capture the response data"""
|
||||
self.response = data
|
||||
self.completed = True
|
||||
|
||||
async def __call__(self, data, final=False):
|
||||
"""Make the responder callable for compatibility with requestor"""
|
||||
await self.send(data)
|
||||
if final:
|
||||
self.completed = True
|
||||
|
||||
class _TokenShim:
|
||||
def __init__(self, token):
|
||||
self.headers = (
|
||||
{"Authorization": f"Bearer {token}"} if token else {}
|
||||
)
|
||||
|
||||
|
||||
class MessageDispatcher:
|
||||
|
||||
def __init__(self, max_workers: int = 10, config_receiver=None, backend=None):
|
||||
def __init__(self, max_workers=10, config_receiver=None, backend=None,
|
||||
auth=None, timeout=120):
|
||||
self.max_workers = max_workers
|
||||
self.semaphore = asyncio.Semaphore(max_workers)
|
||||
self.active_tasks = set()
|
||||
self.backend = backend
|
||||
self.auth = auth
|
||||
|
||||
# Use DispatcherManager for flow and service management
|
||||
if backend and config_receiver:
|
||||
self.dispatcher_manager = DispatcherManager(backend, config_receiver, prefix="rev-gateway")
|
||||
if backend and config_receiver and auth:
|
||||
self.dispatcher_manager = DispatcherManager(
|
||||
backend, config_receiver,
|
||||
auth=auth,
|
||||
prefix="rev-gateway",
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
self.dispatcher_manager = None
|
||||
logger.warning("No backend or config_receiver provided - using fallback mode")
|
||||
|
||||
# Service name mapping from websocket protocol to translator registry
|
||||
logger.warning(
|
||||
"Missing backend, config_receiver, or auth "
|
||||
"— using fallback mode"
|
||||
)
|
||||
|
||||
self.service_mapping = {
|
||||
"text-completion": "text-completion",
|
||||
"graph-rag": "graph-rag",
|
||||
"graph-rag": "graph-rag",
|
||||
"agent": "agent",
|
||||
"embeddings": "embeddings",
|
||||
"graph-embeddings": "graph-embeddings",
|
||||
"triples": "triples",
|
||||
"document-load": "document",
|
||||
"text-load": "text-document",
|
||||
"text-load": "text-document",
|
||||
"flow": "flow",
|
||||
"knowledge": "knowledge",
|
||||
"config": "config",
|
||||
"librarian": "librarian",
|
||||
"document-rag": "document-rag"
|
||||
"document-rag": "document-rag",
|
||||
}
|
||||
|
||||
async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
|
||||
|
||||
async def handle_message(
|
||||
self, message: Dict[Any, Any],
|
||||
sender: Callable[[dict], Awaitable[None]],
|
||||
):
|
||||
async with self.semaphore:
|
||||
task = asyncio.create_task(self._process_message(message))
|
||||
task = asyncio.create_task(
|
||||
self._process_message(message, sender)
|
||||
)
|
||||
self.active_tasks.add(task)
|
||||
|
||||
|
||||
try:
|
||||
result = await task
|
||||
return result
|
||||
await task
|
||||
finally:
|
||||
self.active_tasks.discard(task)
|
||||
|
||||
async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
|
||||
async def _authenticate(self, token):
|
||||
if not self.auth:
|
||||
raise RuntimeError("Auth not configured")
|
||||
return await self.auth.authenticate(_TokenShim(token))
|
||||
|
||||
async def _process_message(
|
||||
self, message: Dict[Any, Any],
|
||||
sender: Callable[[dict], Awaitable[None]],
|
||||
):
|
||||
request_id = message.get('id', str(uuid.uuid4()))
|
||||
service = message.get('service')
|
||||
request_data = message.get('request', {})
|
||||
flow_id = message.get('flow', 'default') # Default flow
|
||||
|
||||
logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
|
||||
|
||||
token = message.get('token', '')
|
||||
flow_id = message.get('flow', 'default')
|
||||
|
||||
logger.info(
|
||||
f"Processing message {request_id} for service "
|
||||
f"{service} on flow {flow_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
if not self.dispatcher_manager:
|
||||
raise RuntimeError("DispatcherManager not available - backend and config_receiver required")
|
||||
|
||||
# Use DispatcherManager for flow-based processing
|
||||
responder = WebSocketResponder()
|
||||
|
||||
# Map websocket service name to dispatcher service name
|
||||
raise RuntimeError(
|
||||
"DispatcherManager not available"
|
||||
)
|
||||
|
||||
identity = await self._authenticate(token)
|
||||
workspace = identity.workspace
|
||||
|
||||
async def responder(resp, fin):
|
||||
await sender({
|
||||
"id": request_id,
|
||||
"response": resp,
|
||||
"complete": fin,
|
||||
})
|
||||
|
||||
dispatcher_service = self.service_mapping.get(service, service)
|
||||
|
||||
# Check if this is a global service or flow service
|
||||
|
||||
from ..gateway.dispatch.manager import global_dispatchers
|
||||
if dispatcher_service in global_dispatchers:
|
||||
# Use global service dispatcher
|
||||
await self.dispatcher_manager.invoke_global_service(
|
||||
request_data, responder, dispatcher_service
|
||||
request_data, responder, dispatcher_service,
|
||||
workspace=workspace,
|
||||
)
|
||||
else:
|
||||
# Use DispatcherManager to process the request through Pulsar queues
|
||||
await self.dispatcher_manager.invoke_flow_service(
|
||||
request_data, responder, flow_id, dispatcher_service
|
||||
request_data, responder, workspace, flow_id,
|
||||
dispatcher_service,
|
||||
)
|
||||
|
||||
# Get the response from the responder
|
||||
if responder.completed:
|
||||
response_data = responder.response
|
||||
else:
|
||||
response_data = {'error': 'No response received'}
|
||||
|
||||
response = {
|
||||
'id': request_id,
|
||||
'response': response_data
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message {request_id}: {e}")
|
||||
response = {
|
||||
'id': request_id,
|
||||
'response': {'error': str(e)}
|
||||
}
|
||||
|
||||
await sender({
|
||||
"id": request_id,
|
||||
"error": {"message": str(e), "type": "error"},
|
||||
"complete": True,
|
||||
})
|
||||
|
||||
logger.info(f"Completed processing message {request_id}")
|
||||
return response
|
||||
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
if self.active_tasks:
|
||||
logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete")
|
||||
logger.info(
|
||||
f"Waiting for {len(self.active_tasks)} active "
|
||||
f"tasks to complete"
|
||||
)
|
||||
await asyncio.gather(*self.active_tasks, return_exceptions=True)
|
||||
|
||||
# DispatcherManager handles its own cleanup
|
||||
|
||||
logger.info("Dispatcher shutdown complete")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
"""
|
||||
Reverse gateway. Initiates outbound WebSocket connections to a remote
|
||||
relay and dispatches incoming requests through the same DispatcherManager
|
||||
pipeline as api-gateway.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
import logging
|
||||
|
|
@ -9,82 +15,86 @@ from typing import Optional
|
|||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from .dispatcher import MessageDispatcher
|
||||
from ..gateway.auth import IamAuth
|
||||
from ..gateway.config.receiver import ConfigReceiver
|
||||
from ..base import get_pubsub
|
||||
from ..base.pubsub import get_pubsub, add_pubsub_args
|
||||
from ..base.logging import setup_logging, add_logging_args
|
||||
|
||||
logger = logging.getLogger("rev_gateway")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
default_websocket = "ws://localhost:7650/out"
|
||||
default_timeout = 600
|
||||
|
||||
class ReverseGateway:
|
||||
|
||||
def __init__(self, websocket_uri: str = None, max_workers: int = 10,
|
||||
pulsar_host: str = None, pulsar_api_key: str = None,
|
||||
pulsar_listener: str = None):
|
||||
# Set default WebSocket URI with environment variable support
|
||||
|
||||
def __init__(self, **config):
|
||||
websocket_uri = config.get("websocket_uri")
|
||||
if websocket_uri is None:
|
||||
websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)
|
||||
|
||||
# Parse and validate the WebSocket URI
|
||||
|
||||
parsed_uri = urlparse(websocket_uri)
|
||||
if parsed_uri.scheme not in ('ws', 'wss'):
|
||||
raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}")
|
||||
raise ValueError(
|
||||
f"WebSocket URI must use ws:// or wss:// scheme, "
|
||||
f"got: {parsed_uri.scheme}"
|
||||
)
|
||||
if not parsed_uri.netloc:
|
||||
raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}")
|
||||
|
||||
# Store parsed components for debugging/logging
|
||||
raise ValueError(
|
||||
f"WebSocket URI must include hostname, "
|
||||
f"got: {websocket_uri}"
|
||||
)
|
||||
|
||||
self.websocket_uri = websocket_uri
|
||||
self.host = parsed_uri.hostname
|
||||
self.port = parsed_uri.port
|
||||
self.scheme = parsed_uri.scheme
|
||||
self.path = parsed_uri.path or "/ws"
|
||||
|
||||
# Construct the full URL (in case path was missing)
|
||||
|
||||
if not parsed_uri.path:
|
||||
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
|
||||
else:
|
||||
self.url = websocket_uri
|
||||
|
||||
self.max_workers = max_workers
|
||||
|
||||
self.max_workers = int(config.get("max_workers", 10))
|
||||
self.timeout = int(config.get("timeout", default_timeout))
|
||||
self.ws: Optional[ClientWebSocketResponse] = None
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.running = False
|
||||
self.reconnect_delay = 3.0
|
||||
|
||||
# Pulsar configuration
|
||||
self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
||||
self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
|
||||
self.pulsar_listener = pulsar_listener
|
||||
|
||||
# Create backend using factory
|
||||
backend_params = {
|
||||
'pulsar_host': self.pulsar_host,
|
||||
'pulsar_api_key': self.pulsar_api_key,
|
||||
'pulsar_listener': self.pulsar_listener,
|
||||
}
|
||||
self.backend = get_pubsub(**backend_params)
|
||||
self.backend = get_pubsub(**config)
|
||||
|
||||
# Initialize config receiver
|
||||
self.config_receiver = ConfigReceiver(self.backend)
|
||||
self.auth = IamAuth(
|
||||
backend=self.backend,
|
||||
id=config.get("id", "rev-gateway"),
|
||||
)
|
||||
|
||||
self.config_receiver = ConfigReceiver(
|
||||
self.backend, auth=self.auth,
|
||||
)
|
||||
|
||||
self.dispatcher = MessageDispatcher(
|
||||
self.max_workers, self.config_receiver, self.backend,
|
||||
auth=self.auth, timeout=self.timeout,
|
||||
)
|
||||
|
||||
# Initialize dispatcher with config_receiver and backend - must be created after config_receiver
|
||||
self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
try:
|
||||
if self.session is None:
|
||||
self.session = ClientSession()
|
||||
|
||||
|
||||
logger.info(f"Connecting to {self.url}")
|
||||
self.ws = await self.session.ws_connect(self.url)
|
||||
logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}")
|
||||
logger.info(
|
||||
f"WebSocket connection established to "
|
||||
f"{self.host}:{self.port or 'default'}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to {self.url}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def disconnect(self):
|
||||
if self.ws and not self.ws.closed:
|
||||
await self.ws.close()
|
||||
|
|
@ -92,32 +102,31 @@ class ReverseGateway:
|
|||
await self.session.close()
|
||||
self.ws = None
|
||||
self.session = None
|
||||
|
||||
|
||||
async def send_message(self, message: dict):
|
||||
if self.ws and not self.ws.closed:
|
||||
try:
|
||||
await self.ws.send_str(json.dumps(message))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message: {e}")
|
||||
|
||||
|
||||
async def handle_message(self, message: str):
|
||||
try:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
|
||||
msg_data = json.loads(message)
|
||||
response = await self.dispatcher.handle_message(msg_data)
|
||||
|
||||
if response:
|
||||
await self.send_message(response)
|
||||
|
||||
await self.dispatcher.handle_message(
|
||||
msg_data, self.send_message,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
|
||||
|
||||
async def listen(self):
|
||||
while self.running and self.ws and not self.ws.closed:
|
||||
try:
|
||||
msg = await self.ws.receive()
|
||||
|
||||
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
await self.handle_message(msg.data)
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
|
|
@ -125,31 +134,33 @@ class ReverseGateway:
|
|||
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
|
||||
logger.warning("WebSocket closed or error occurred")
|
||||
break
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in listen loop: {e}")
|
||||
break
|
||||
|
||||
|
||||
async def run(self):
|
||||
self.running = True
|
||||
logger.info("Starting reverse gateway")
|
||||
|
||||
# Start config receiver
|
||||
logger.info("Starting config receiver")
|
||||
|
||||
await self.auth.start()
|
||||
await self.config_receiver.start()
|
||||
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
if await self.connect():
|
||||
await self.listen()
|
||||
else:
|
||||
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")
|
||||
|
||||
logger.warning(
|
||||
f"Connection failed, retrying in "
|
||||
f"{self.reconnect_delay} seconds"
|
||||
)
|
||||
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
if self.running:
|
||||
await asyncio.sleep(self.reconnect_delay)
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutdown requested")
|
||||
break
|
||||
|
|
@ -157,77 +168,69 @@ class ReverseGateway:
|
|||
logger.error(f"Unexpected error: {e}")
|
||||
if self.running:
|
||||
await asyncio.sleep(self.reconnect_delay)
|
||||
|
||||
|
||||
await self.shutdown()
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Shutting down reverse gateway")
|
||||
self.running = False
|
||||
await self.dispatcher.shutdown()
|
||||
await self.disconnect()
|
||||
|
||||
# Close backend
|
||||
if hasattr(self, 'backend'):
|
||||
self.backend.close()
|
||||
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
def parse_args():
|
||||
|
||||
def run():
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="reverse-gateway",
|
||||
description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge"
|
||||
description=__doc__,
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--id',
|
||||
default='rev-gateway',
|
||||
help='Service identifier (default: rev-gateway)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--websocket-uri',
|
||||
default=None,
|
||||
help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)'
|
||||
help=f'WebSocket URI to connect to (default: {default_websocket})',
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--max-workers',
|
||||
type=int,
|
||||
default=10,
|
||||
help='Maximum concurrent message handlers (default: 10)'
|
||||
help='Maximum concurrent message handlers (default: 10)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-p', '--pulsar-host',
|
||||
default=None,
|
||||
help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--pulsar-api-key',
|
||||
default=None,
|
||||
help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--pulsar-listener',
|
||||
default=None,
|
||||
help='Pulsar listener name'
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def run():
|
||||
args = parse_args()
|
||||
|
||||
gateway = ReverseGateway(
|
||||
websocket_uri=args.websocket_uri,
|
||||
max_workers=args.max_workers,
|
||||
pulsar_host=args.pulsar_host,
|
||||
pulsar_api_key=args.pulsar_api_key,
|
||||
pulsar_listener=args.pulsar_listener
|
||||
parser.add_argument(
|
||||
'--timeout',
|
||||
type=int,
|
||||
default=default_timeout,
|
||||
help=f'Request timeout in seconds (default: {default_timeout})',
|
||||
)
|
||||
|
||||
|
||||
add_pubsub_args(parser)
|
||||
add_logging_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = vars(args)
|
||||
|
||||
setup_logging(args)
|
||||
|
||||
gateway = ReverseGateway(**args)
|
||||
|
||||
logger.info(f"Starting reverse gateway:")
|
||||
logger.info(f" WebSocket URI: {gateway.url}")
|
||||
logger.info(f" Max workers: {args.max_workers}")
|
||||
logger.info(f" Pulsar host: {gateway.pulsar_host}")
|
||||
|
||||
logger.info(f" Max workers: {gateway.max_workers}")
|
||||
|
||||
try:
|
||||
asyncio.run(gateway.run())
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue