Feature/reverse gateway (#416)

* Created reverse gateway

* Dispatched invoke message translations

* Added config receiver

* Provide a script to start rev-gateway
This commit is contained in:
cybermaggedon 2025-06-24 11:19:20 +01:00 committed by GitHub
parent fcab3aeb0e
commit f08e3c1b27
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 392 additions and 7 deletions

View file

@ -71,6 +71,7 @@ setuptools.setup(
scripts=[
"scripts/agent-manager-react",
"scripts/api-gateway",
"scripts/rev-gateway",
"scripts/chunker-recursive",
"scripts/chunker-token",
"scripts/config-svc",

View file

@ -81,10 +81,11 @@ class DispatcherWrapper:
class DispatcherManager:
def __init__(self, pulsar_client, config_receiver):
def __init__(self, pulsar_client, config_receiver, prefix="api-gateway"):
self.pulsar_client = pulsar_client
self.config_receiver = config_receiver
self.config_receiver.add_handler(self)
self.prefix = prefix
self.flows = {}
self.dispatchers = {}
@ -133,8 +134,8 @@ class DispatcherManager:
dispatcher = global_dispatchers[kind](
pulsar_client = self.pulsar_client,
timeout = 120,
consumer = f"api-gateway-{kind}-request",
subscriber = f"api-gateway-{kind}-request",
consumer = f"{self.prefix}-{kind}-request",
subscriber = f"{self.prefix}-{kind}-request",
)
await dispatcher.start()
@ -226,8 +227,8 @@ class DispatcherManager:
ws = ws,
running = running,
queue = qconfig,
consumer = f"api-gateway-{id}",
subscriber = f"api-gateway-{id}",
consumer = f"{self.prefix}-{id}",
subscriber = f"{self.prefix}-{id}",
)
return dispatcher
@ -268,8 +269,8 @@ class DispatcherManager:
request_queue = qconfig["request"],
response_queue = qconfig["response"],
timeout = 120,
consumer = f"api-gateway-{flow}-{kind}-request",
subscriber = f"api-gateway-{flow}-{kind}-request",
consumer = f"{self.prefix}-{flow}-{kind}-request",
subscriber = f"{self.prefix}-{flow}-{kind}-request",
)
elif kind in sender_dispatchers:
dispatcher = sender_dispatchers[kind](

View file

@ -73,6 +73,7 @@ class Api:
self.dispatcher_manager = DispatcherManager(
pulsar_client = self.pulsar_client,
config_receiver = self.config_receiver,
prefix = "gateway",
)
self.endpoint_manager = EndpointManager(

View file

@ -0,0 +1 @@
from . service import run

View file

@ -0,0 +1,11 @@
import logging
from .service import run
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
if __name__ == "__main__":
run()

View file

@ -0,0 +1,130 @@
import asyncio
import logging
import uuid
from typing import Dict, Any, Optional
from trustgraph.messaging import TranslatorRegistry
from ..gateway.dispatch.manager import DispatcherManager
logger = logging.getLogger("dispatcher")
logger.setLevel(logging.INFO)
class WebSocketResponder:
"""Simple responder that captures response for websocket return"""
def __init__(self):
self.response = None
self.completed = False
async def send(self, data):
"""Capture the response data"""
self.response = data
self.completed = True
async def __call__(self, data, final=False):
"""Make the responder callable for compatibility with requestor"""
await self.send(data)
if final:
self.completed = True
class MessageDispatcher:
def __init__(self, max_workers: int = 10, config_receiver=None, pulsar_client=None):
self.max_workers = max_workers
self.semaphore = asyncio.Semaphore(max_workers)
self.active_tasks = set()
self.pulsar_client = pulsar_client
# Use DispatcherManager for flow and service management
if pulsar_client and config_receiver:
self.dispatcher_manager = DispatcherManager(pulsar_client, config_receiver, prefix="rev-gateway")
else:
self.dispatcher_manager = None
logger.warning("No pulsar_client or config_receiver provided - using fallback mode")
# Service name mapping from websocket protocol to translator registry
self.service_mapping = {
"text-completion": "text-completion",
"graph-rag": "graph-rag",
"agent": "agent",
"embeddings": "embeddings",
"graph-embeddings": "graph-embeddings",
"triples": "triples",
"document-load": "document",
"text-load": "text-document",
"flow": "flow",
"knowledge": "knowledge",
"config": "config",
"librarian": "librarian",
"document-rag": "document-rag"
}
async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
async with self.semaphore:
task = asyncio.create_task(self._process_message(message))
self.active_tasks.add(task)
try:
result = await task
return result
finally:
self.active_tasks.discard(task)
async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]:
request_id = message.get('id', str(uuid.uuid4()))
service = message.get('service')
request_data = message.get('request', {})
flow_id = message.get('flow', 'default') # Default flow
logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
try:
if not self.dispatcher_manager:
raise RuntimeError("DispatcherManager not available - pulsar_client and config_receiver required")
# Use DispatcherManager for flow-based processing
responder = WebSocketResponder()
# Map websocket service name to dispatcher service name
dispatcher_service = self.service_mapping.get(service, service)
# Check if this is a global service or flow service
from ..gateway.dispatch.manager import global_dispatchers
if dispatcher_service in global_dispatchers:
# Use global service dispatcher
await self.dispatcher_manager.invoke_global_service(
request_data, responder, dispatcher_service
)
else:
# Use DispatcherManager to process the request through Pulsar queues
await self.dispatcher_manager.invoke_flow_service(
request_data, responder, flow_id, dispatcher_service
)
# Get the response from the responder
if responder.completed:
response_data = responder.response
else:
response_data = {'error': 'No response received'}
response = {
'id': request_id,
'response': response_data
}
except Exception as e:
logger.error(f"Error processing message {request_id}: {e}")
response = {
'id': request_id,
'response': {'error': str(e)}
}
logger.info(f"Completed processing message {request_id}")
return response
async def shutdown(self):
if self.active_tasks:
logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete")
await asyncio.gather(*self.active_tasks, return_exceptions=True)
# DispatcherManager handles its own cleanup
logger.info("Dispatcher shutdown complete")

View file

@ -0,0 +1,240 @@
import asyncio
import argparse
import logging
import json
import sys
import os
from aiohttp import ClientSession, WSMsgType, ClientWebSocketResponse
from typing import Optional
from urllib.parse import urlparse, urlunparse
import pulsar
from .dispatcher import MessageDispatcher
from ..gateway.config.receiver import ConfigReceiver
logger = logging.getLogger("rev_gateway")
logger.setLevel(logging.INFO)
class ReverseGateway:
def __init__(self, websocket_uri: str = None, max_workers: int = 10,
pulsar_host: str = None, pulsar_api_key: str = None,
pulsar_listener: str = None):
# Set default WebSocket URI with environment variable support
if websocket_uri is None:
websocket_uri = os.getenv("WEBSOCKET_URI", "wss://api.trustgraph.ai/ws")
# Parse and validate the WebSocket URI
parsed_uri = urlparse(websocket_uri)
if parsed_uri.scheme not in ('ws', 'wss'):
raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}")
if not parsed_uri.netloc:
raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}")
# Store parsed components for debugging/logging
self.websocket_uri = websocket_uri
self.host = parsed_uri.hostname
self.port = parsed_uri.port
self.scheme = parsed_uri.scheme
self.path = parsed_uri.path or "/ws"
# Construct the full URL (in case path was missing)
if not parsed_uri.path:
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
else:
self.url = websocket_uri
self.max_workers = max_workers
self.ws: Optional[ClientWebSocketResponse] = None
self.session: Optional[ClientSession] = None
self.running = False
self.reconnect_delay = 3.0
# Pulsar configuration
self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
self.pulsar_listener = pulsar_listener
# Initialize Pulsar client
if self.pulsar_api_key:
self.pulsar_client = pulsar.Client(
self.pulsar_host,
listener_name=self.pulsar_listener,
authentication=pulsar.AuthenticationToken(self.pulsar_api_key)
)
else:
self.pulsar_client = pulsar.Client(
self.pulsar_host,
listener_name=self.pulsar_listener
)
# Initialize config receiver
self.config_receiver = ConfigReceiver(self.pulsar_client)
# Initialize dispatcher with config_receiver and pulsar_client - must be created after config_receiver
self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.pulsar_client)
async def connect(self) -> bool:
try:
if self.session is None:
self.session = ClientSession()
logger.info(f"Connecting to {self.url}")
self.ws = await self.session.ws_connect(self.url)
logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}")
return True
except Exception as e:
logger.error(f"Failed to connect to {self.url}: {e}")
return False
async def disconnect(self):
if self.ws and not self.ws.closed:
await self.ws.close()
if self.session and not self.session.closed:
await self.session.close()
self.ws = None
self.session = None
async def send_message(self, message: dict):
if self.ws and not self.ws.closed:
try:
await self.ws.send_str(json.dumps(message))
except Exception as e:
logger.error(f"Failed to send message: {e}")
async def handle_message(self, message: str):
try:
print(f"Received: {message}", flush=True)
msg_data = json.loads(message)
response = await self.dispatcher.handle_message(msg_data)
if response:
await self.send_message(response)
except Exception as e:
logger.error(f"Error handling message: {e}")
async def listen(self):
while self.running and self.ws and not self.ws.closed:
try:
msg = await self.ws.receive()
if msg.type == WSMsgType.TEXT:
await self.handle_message(msg.data)
elif msg.type == WSMsgType.BINARY:
await self.handle_message(msg.data.decode('utf-8'))
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
logger.warning("WebSocket closed or error occurred")
break
except Exception as e:
logger.error(f"Error in listen loop: {e}")
break
async def run(self):
self.running = True
logger.info("Starting reverse gateway")
# Start config receiver
logger.info("Starting config receiver")
await self.config_receiver.start()
while self.running:
try:
if await self.connect():
await self.listen()
else:
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")
await self.disconnect()
if self.running:
await asyncio.sleep(self.reconnect_delay)
except KeyboardInterrupt:
logger.info("Shutdown requested")
break
except Exception as e:
logger.error(f"Unexpected error: {e}")
if self.running:
await asyncio.sleep(self.reconnect_delay)
await self.shutdown()
async def shutdown(self):
logger.info("Shutting down reverse gateway")
self.running = False
await self.dispatcher.shutdown()
await self.disconnect()
# Close Pulsar client
if hasattr(self, 'pulsar_client'):
self.pulsar_client.close()
def stop(self):
self.running = False
def parse_args():
parser = argparse.ArgumentParser(
prog="reverse-gateway",
description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge"
)
parser.add_argument(
'--websocket-uri',
default=None,
help='WebSocket URI to connect to (default: wss://api.trustgraph.ai/ws or WEBSOCKET_URI env var)'
)
parser.add_argument(
'--max-workers',
type=int,
default=10,
help='Maximum concurrent message handlers (default: 10)'
)
parser.add_argument(
'--pulsar-host',
default=None,
help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
)
parser.add_argument(
'--pulsar-api-key',
default=None,
help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
)
parser.add_argument(
'--pulsar-listener',
default=None,
help='Pulsar listener name'
)
return parser.parse_args()
def run():
args = parse_args()
gateway = ReverseGateway(
websocket_uri=args.websocket_uri,
max_workers=args.max_workers,
pulsar_host=args.pulsar_host,
pulsar_api_key=args.pulsar_api_key,
pulsar_listener=args.pulsar_listener
)
print(f"Starting reverse gateway:")
print(f" WebSocket URI: {gateway.url}")
print(f" Max workers: {args.max_workers}")
print(f" Pulsar host: {gateway.pulsar_host}")
try:
asyncio.run(gateway.run())
except KeyboardInterrupt:
print("\nShutdown requested by user")
except Exception as e:
print(f"Fatal error: {e}")
sys.exit(1)