diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py
index 1cda836d..562c5389 100644
--- a/trustgraph-flow/setup.py
+++ b/trustgraph-flow/setup.py
@@ -71,6 +71,7 @@ setuptools.setup(
     scripts=[
         "scripts/agent-manager-react",
         "scripts/api-gateway",
+        "scripts/rev-gateway",
         "scripts/chunker-recursive",
         "scripts/chunker-token",
         "scripts/config-svc",
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py
index 189ea5e3..0b5b26f1 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py
@@ -81,10 +81,11 @@ class DispatcherWrapper:
 
 class DispatcherManager:
 
-    def __init__(self, pulsar_client, config_receiver):
+    def __init__(self, pulsar_client, config_receiver, prefix="api-gateway"):
         self.pulsar_client = pulsar_client
         self.config_receiver = config_receiver
         self.config_receiver.add_handler(self)
+        self.prefix = prefix
 
         self.flows = {}
         self.dispatchers = {}
@@ -133,8 +134,8 @@ class DispatcherManager:
         dispatcher = global_dispatchers[kind](
             pulsar_client = self.pulsar_client,
             timeout = 120,
-            consumer = f"api-gateway-{kind}-request",
-            subscriber = f"api-gateway-{kind}-request",
+            consumer = f"{self.prefix}-{kind}-request",
+            subscriber = f"{self.prefix}-{kind}-request",
         )
 
         await dispatcher.start()
@@ -226,8 +227,8 @@ class DispatcherManager:
             ws = ws,
             running = running,
             queue = qconfig,
-            consumer = f"api-gateway-{id}",
-            subscriber = f"api-gateway-{id}",
+            consumer = f"{self.prefix}-{id}",
+            subscriber = f"{self.prefix}-{id}",
         )
 
         return dispatcher
@@ -268,8 +269,8 @@ class DispatcherManager:
                 request_queue = qconfig["request"],
                 response_queue = qconfig["response"],
                 timeout = 120,
-                consumer = f"api-gateway-{flow}-{kind}-request",
-                subscriber = f"api-gateway-{flow}-{kind}-request",
+                consumer = f"{self.prefix}-{flow}-{kind}-request",
+                subscriber = f"{self.prefix}-{flow}-{kind}-request",
             )
         elif kind in sender_dispatchers:
             dispatcher = sender_dispatchers[kind](
diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py
index 97406422..ee66b9d3 100755
--- a/trustgraph-flow/trustgraph/gateway/service.py
+++ b/trustgraph-flow/trustgraph/gateway/service.py
@@ -73,6 +73,7 @@ class Api:
         self.dispatcher_manager = DispatcherManager(
             pulsar_client = self.pulsar_client,
             config_receiver = self.config_receiver,
+            prefix = "gateway",
         )
 
         self.endpoint_manager = EndpointManager(
diff --git a/trustgraph-flow/trustgraph/rev_gateway/__init__.py b/trustgraph-flow/trustgraph/rev_gateway/__init__.py
new file mode 100644
index 00000000..1be89162
--- /dev/null
+++ b/trustgraph-flow/trustgraph/rev_gateway/__init__.py
@@ -0,0 +1 @@
+from . service import run
diff --git a/trustgraph-flow/trustgraph/rev_gateway/__main__.py b/trustgraph-flow/trustgraph/rev_gateway/__main__.py
new file mode 100644
index 00000000..70262bc8
--- /dev/null
+++ b/trustgraph-flow/trustgraph/rev_gateway/__main__.py
@@ -0,0 +1,11 @@
+import logging
+from .service import run
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+
+if __name__ == "__main__":
+    run()
+
diff --git a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py
new file mode 100644
index 00000000..03e79c0d
--- /dev/null
+++ b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py
@@ -0,0 +1,154 @@
+"""Dispatch reverse-gateway websocket requests to TrustGraph services."""
+
+import asyncio
+import logging
+import uuid
+from typing import Dict, Any, Optional
+from trustgraph.messaging import TranslatorRegistry
+from ..gateway.dispatch.manager import DispatcherManager, global_dispatchers
+
+logger = logging.getLogger("dispatcher")
+logger.setLevel(logging.INFO)
+
+class WebSocketResponder:
+    """Capture a dispatcher's response so it can be returned over the websocket.
+
+    Each call overwrites ``response``, so only the last payload is kept.
+    """
+
+    def __init__(self):
+        self.response = None
+        self.completed = False
+
+    async def send(self, data):
+        """Record ``data`` as the response and mark the responder completed."""
+        self.response = data
+        self.completed = True
+
+    async def __call__(self, data, final=False):
+        """Make the responder callable for compatibility with requestor.
+
+        ``final`` is accepted but unused: ``send`` already marks the
+        responder completed on every call.
+        """
+        await self.send(data)
+
+class MessageDispatcher:
+    """Route decoded websocket messages to global or per-flow dispatchers."""
+
+    def __init__(self, max_workers: int = 10, config_receiver=None, pulsar_client=None):
+        self.max_workers = max_workers
+        self.semaphore = asyncio.Semaphore(max_workers)
+        self.active_tasks = set()
+        self.pulsar_client = pulsar_client
+
+        # Use DispatcherManager for flow and service management
+        if pulsar_client and config_receiver:
+            self.dispatcher_manager = DispatcherManager(
+                pulsar_client, config_receiver, prefix="rev-gateway"
+            )
+        else:
+            self.dispatcher_manager = None
+            logger.warning("No pulsar_client or config_receiver provided - using fallback mode")
+
+        # Service name mapping from websocket protocol to translator registry
+        self.service_mapping = {
+            "text-completion": "text-completion",
+            "graph-rag": "graph-rag",
+            "agent": "agent",
+            "embeddings": "embeddings",
+            "graph-embeddings": "graph-embeddings",
+            "triples": "triples",
+            "document-load": "document",
+            "text-load": "text-document",
+            "flow": "flow",
+            "knowledge": "knowledge",
+            "config": "config",
+            "librarian": "librarian",
+            "document-rag": "document-rag"
+        }
+
+    async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
+        """Process one message, with at most ``max_workers`` in flight at once."""
+        async with self.semaphore:
+            task = asyncio.create_task(self._process_message(message))
+            self.active_tasks.add(task)
+
+            try:
+                result = await task
+                return result
+            finally:
+                self.active_tasks.discard(task)
+
+    async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]:
+        """Invoke the requested service and wrap its reply as ``{'id', 'response'}``.
+
+        Failures are reported in-band as ``{'error': ...}`` rather than raised.
+        """
+        # Valid JSON need not be an object; read fields defensively so that a
+        # malformed message still produces an error reply instead of raising.
+        fields = message if isinstance(message, dict) else {}
+        request_id = fields.get('id', str(uuid.uuid4()))
+        service = fields.get('service')
+        request_data = fields.get('request', {})
+        flow_id = fields.get('flow', 'default') # Default flow
+
+        logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
+
+        try:
+            if not isinstance(message, dict):
+                raise ValueError("Message must be a JSON object")
+
+            if not service:
+                raise ValueError("Message is missing the 'service' field")
+
+            if not self.dispatcher_manager:
+                raise RuntimeError("DispatcherManager not available - pulsar_client and config_receiver required")
+
+            # Use DispatcherManager for flow-based processing
+            responder = WebSocketResponder()
+
+            # Map websocket service name to dispatcher service name
+            dispatcher_service = self.service_mapping.get(service, service)
+
+            # Check if this is a global service or flow service
+            if dispatcher_service in global_dispatchers:
+                # Use global service dispatcher
+                await self.dispatcher_manager.invoke_global_service(
+                    request_data, responder, dispatcher_service
+                )
+            else:
+                # Use DispatcherManager to process the request through Pulsar queues
+                await self.dispatcher_manager.invoke_flow_service(
+                    request_data, responder, flow_id, dispatcher_service
+                )
+
+            # Get the response from the responder
+            if responder.completed:
+                response_data = responder.response
+            else:
+                response_data = {'error': 'No response received'}
+
+            response = {
+                'id': request_id,
+                'response': response_data
+            }
+
+        except Exception as e:
+            logger.error(f"Error processing message {request_id}: {e}")
+            response = {
+                'id': request_id,
+                'response': {'error': str(e)}
+            }
+
+        logger.info(f"Completed processing message {request_id}")
+        return response
+
+    async def shutdown(self):
+        """Wait for in-flight message tasks to finish."""
+        if self.active_tasks:
+            logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete")
+            await asyncio.gather(*self.active_tasks, return_exceptions=True)
+
+        # DispatcherManager handles its own cleanup
+        logger.info("Dispatcher shutdown complete")
diff --git a/trustgraph-flow/trustgraph/rev_gateway/service.py b/trustgraph-flow/trustgraph/rev_gateway/service.py
new file mode 100644
index 00000000..e6ebda9b
--- /dev/null
+++ b/trustgraph-flow/trustgraph/rev_gateway/service.py
@@ -0,0 +1,262 @@
+"""Reverse gateway: bridge an outbound websocket connection to Pulsar services."""
+
+import asyncio
+import argparse
+import logging
+import json
+import sys
+import os
+from aiohttp import ClientSession, WSMsgType, ClientWebSocketResponse
+from typing import Optional
+from urllib.parse import urlparse, urlunparse
+import pulsar
+
+from .dispatcher import MessageDispatcher
+from ..gateway.config.receiver import ConfigReceiver
+
+logger = logging.getLogger("rev_gateway")
+logger.setLevel(logging.INFO)
+
+class ReverseGateway:
+    """Connect out to a websocket server and answer the requests it sends."""
+
+    def __init__(self, websocket_uri: str = None, max_workers: int = 10,
+                 pulsar_host: str = None, pulsar_api_key: str = None,
+                 pulsar_listener: str = None):
+        # Set default WebSocket URI with environment variable support
+        if websocket_uri is None:
+            websocket_uri = os.getenv("WEBSOCKET_URI", "wss://api.trustgraph.ai/ws")
+
+        # Parse and validate the WebSocket URI
+        parsed_uri = urlparse(websocket_uri)
+        if parsed_uri.scheme not in ('ws', 'wss'):
+            raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}")
+        if not parsed_uri.netloc:
+            raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}")
+
+        # Store parsed components for debugging/logging
+        self.websocket_uri = websocket_uri
+        self.host = parsed_uri.hostname
+        self.port = parsed_uri.port
+        self.scheme = parsed_uri.scheme
+        self.path = parsed_uri.path or "/ws"
+
+        # Default a missing path to /ws; urlunparse keeps any query string,
+        # which rebuilding from scheme and netloc alone would drop
+        if not parsed_uri.path:
+            self.url = urlunparse(parsed_uri._replace(path="/ws"))
+        else:
+            self.url = websocket_uri
+
+        self.max_workers = max_workers
+        self.ws: Optional[ClientWebSocketResponse] = None
+        self.session: Optional[ClientSession] = None
+        self.running = False
+        self.reconnect_delay = 3.0
+
+        # Pulsar configuration
+        self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
+        self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
+        self.pulsar_listener = pulsar_listener
+
+        # Initialize Pulsar client, authenticating only when a key is set
+        client_args = {"listener_name": self.pulsar_listener}
+        if self.pulsar_api_key:
+            client_args["authentication"] = pulsar.AuthenticationToken(self.pulsar_api_key)
+        self.pulsar_client = pulsar.Client(self.pulsar_host, **client_args)
+
+        # Initialize config receiver
+        self.config_receiver = ConfigReceiver(self.pulsar_client)
+
+        # Initialize dispatcher with config_receiver and pulsar_client - must be created after config_receiver
+        self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.pulsar_client)
+
+    async def connect(self) -> bool:
+        """Open the websocket; return True on success, False on failure."""
+        try:
+            if self.session is None:
+                self.session = ClientSession()
+
+            logger.info(f"Connecting to {self.url}")
+            self.ws = await self.session.ws_connect(self.url)
+            logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to connect to {self.url}: {e}")
+            return False
+
+    async def disconnect(self):
+        """Close the websocket and HTTP session if they are open."""
+        if self.ws and not self.ws.closed:
+            await self.ws.close()
+        if self.session and not self.session.closed:
+            await self.session.close()
+        self.ws = None
+        self.session = None
+
+    async def send_message(self, message: dict):
+        """Send ``message`` as JSON text; failures are logged, not raised."""
+        if self.ws and not self.ws.closed:
+            try:
+                await self.ws.send_str(json.dumps(message))
+            except Exception as e:
+                logger.error(f"Failed to send message: {e}")
+
+    async def handle_message(self, message: str):
+        """Decode one JSON message, dispatch it and send back any response."""
+        try:
+            # Payloads may carry user data, so log them at debug level
+            # rather than printing every message to stdout
+            logger.debug(f"Received: {message}")
+
+            msg_data = json.loads(message)
+            response = await self.dispatcher.handle_message(msg_data)
+
+            if response:
+                await self.send_message(response)
+
+        except Exception as e:
+            logger.error(f"Error handling message: {e}")
+
+    async def listen(self):
+        """Receive and handle messages until the socket closes or errors.
+
+        Messages are handled one at a time, in arrival order.
+        """
+        while self.running and self.ws and not self.ws.closed:
+            try:
+                msg = await self.ws.receive()
+
+                if msg.type == WSMsgType.TEXT:
+                    await self.handle_message(msg.data)
+                elif msg.type == WSMsgType.BINARY:
+                    try:
+                        text = msg.data.decode('utf-8')
+                    except UnicodeDecodeError as e:
+                        # One undecodable frame should not drop the connection
+                        logger.error(f"Ignoring non-UTF-8 binary message: {e}")
+                        continue
+                    await self.handle_message(text)
+                elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING,
+                                  WSMsgType.CLOSED, WSMsgType.ERROR):
+                    logger.warning("WebSocket closed or error occurred")
+                    break
+
+            except Exception as e:
+                logger.error(f"Error in listen loop: {e}")
+                break
+
+    async def run(self):
+        """Connect, listen and reconnect until stopped, then shut down."""
+        self.running = True
+        logger.info("Starting reverse gateway")
+
+        # Start config receiver
+        logger.info("Starting config receiver")
+        await self.config_receiver.start()
+
+        try:
+            while self.running:
+                try:
+                    if await self.connect():
+                        await self.listen()
+                    else:
+                        logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")
+
+                    await self.disconnect()
+
+                    if self.running:
+                        await asyncio.sleep(self.reconnect_delay)
+
+                except KeyboardInterrupt:
+                    logger.info("Shutdown requested")
+                    break
+                except Exception as e:
+                    logger.error(f"Unexpected error: {e}")
+                    if self.running:
+                        await asyncio.sleep(self.reconnect_delay)
+        finally:
+            # Also runs on cancellation (e.g. Ctrl-C under asyncio.run), which
+            # `except Exception` does not catch, so resources are released
+            await self.shutdown()
+
+    async def shutdown(self):
+        """Stop the run loop and release the dispatcher, socket and Pulsar client."""
+        logger.info("Shutting down reverse gateway")
+        self.running = False
+        await self.dispatcher.shutdown()
+        await self.disconnect()
+
+        # Close Pulsar client
+        if hasattr(self, 'pulsar_client'):
+            self.pulsar_client.close()
+
+    def stop(self):
+        """Clear the running flag so the run and listen loops exit."""
+        self.running = False
+
+def parse_args():
+    """Parse the reverse gateway command-line arguments."""
+    parser = argparse.ArgumentParser(
+        prog="reverse-gateway",
+        description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge"
+    )
+
+    parser.add_argument(
+        '--websocket-uri',
+        default=None,
+        help='WebSocket URI to connect to (default: wss://api.trustgraph.ai/ws or WEBSOCKET_URI env var)'
+    )
+
+    parser.add_argument(
+        '--max-workers',
+        type=int,
+        default=10,
+        help='Maximum concurrent message handlers (default: 10)'
+    )
+
+    parser.add_argument(
+        '--pulsar-host',
+        default=None,
+        help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
+    )
+
+    parser.add_argument(
+        '--pulsar-api-key',
+        default=None,
+        help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
+    )
+
+    parser.add_argument(
+        '--pulsar-listener',
+        default=None,
+        help='Pulsar listener name'
+    )
+
+    return parser.parse_args()
+
+def run():
+    """Command-line entry point: build a ReverseGateway and run it."""
+    args = parse_args()
+
+    gateway = ReverseGateway(
+        websocket_uri=args.websocket_uri,
+        max_workers=args.max_workers,
+        pulsar_host=args.pulsar_host,
+        pulsar_api_key=args.pulsar_api_key,
+        pulsar_listener=args.pulsar_listener
+    )
+
+    print(f"Starting reverse gateway:")
+    print(f" WebSocket URI: {gateway.url}")
+    print(f" Max workers: {args.max_workers}")
+    print(f" Pulsar host: {gateway.pulsar_host}")
+
+    try:
+        asyncio.run(gateway.run())
+    except KeyboardInterrupt:
+        print("\nShutdown requested by user")
+    except Exception as e:
+        print(f"Fatal error: {e}")
+        sys.exit(1)