mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Feature/reverse gateway (#416)
* Created reverse gateway * Dispatched invoke message translations * Added config receiver * Provide a script to start rev-gateway
This commit is contained in:
parent
fcab3aeb0e
commit
f08e3c1b27
7 changed files with 392 additions and 7 deletions
|
|
@ -71,6 +71,7 @@ setuptools.setup(
|
|||
scripts=[
|
||||
"scripts/agent-manager-react",
|
||||
"scripts/api-gateway",
|
||||
"scripts/rev-gateway",
|
||||
"scripts/chunker-recursive",
|
||||
"scripts/chunker-token",
|
||||
"scripts/config-svc",
|
||||
|
|
|
|||
|
|
@ -81,10 +81,11 @@ class DispatcherWrapper:
|
|||
|
||||
class DispatcherManager:
|
||||
|
||||
def __init__(self, pulsar_client, config_receiver):
|
||||
def __init__(self, pulsar_client, config_receiver, prefix="api-gateway"):
|
||||
self.pulsar_client = pulsar_client
|
||||
self.config_receiver = config_receiver
|
||||
self.config_receiver.add_handler(self)
|
||||
self.prefix = prefix
|
||||
|
||||
self.flows = {}
|
||||
self.dispatchers = {}
|
||||
|
|
@ -133,8 +134,8 @@ class DispatcherManager:
|
|||
dispatcher = global_dispatchers[kind](
|
||||
pulsar_client = self.pulsar_client,
|
||||
timeout = 120,
|
||||
consumer = f"api-gateway-{kind}-request",
|
||||
subscriber = f"api-gateway-{kind}-request",
|
||||
consumer = f"{self.prefix}-{kind}-request",
|
||||
subscriber = f"{self.prefix}-{kind}-request",
|
||||
)
|
||||
|
||||
await dispatcher.start()
|
||||
|
|
@ -226,8 +227,8 @@ class DispatcherManager:
|
|||
ws = ws,
|
||||
running = running,
|
||||
queue = qconfig,
|
||||
consumer = f"api-gateway-{id}",
|
||||
subscriber = f"api-gateway-{id}",
|
||||
consumer = f"{self.prefix}-{id}",
|
||||
subscriber = f"{self.prefix}-{id}",
|
||||
)
|
||||
|
||||
return dispatcher
|
||||
|
|
@ -268,8 +269,8 @@ class DispatcherManager:
|
|||
request_queue = qconfig["request"],
|
||||
response_queue = qconfig["response"],
|
||||
timeout = 120,
|
||||
consumer = f"api-gateway-{flow}-{kind}-request",
|
||||
subscriber = f"api-gateway-{flow}-{kind}-request",
|
||||
consumer = f"{self.prefix}-{flow}-{kind}-request",
|
||||
subscriber = f"{self.prefix}-{flow}-{kind}-request",
|
||||
)
|
||||
elif kind in sender_dispatchers:
|
||||
dispatcher = sender_dispatchers[kind](
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class Api:
|
|||
self.dispatcher_manager = DispatcherManager(
|
||||
pulsar_client = self.pulsar_client,
|
||||
config_receiver = self.config_receiver,
|
||||
prefix = "gateway",
|
||||
)
|
||||
|
||||
self.endpoint_manager = EndpointManager(
|
||||
|
|
|
|||
1
trustgraph-flow/trustgraph/rev_gateway/__init__.py
Normal file
1
trustgraph-flow/trustgraph/rev_gateway/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from . service import run
|
||||
11
trustgraph-flow/trustgraph/rev_gateway/__main__.py
Normal file
11
trustgraph-flow/trustgraph/rev_gateway/__main__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import logging
|
||||
from .service import run
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
|
||||
130
trustgraph-flow/trustgraph/rev_gateway/dispatcher.py
Normal file
130
trustgraph-flow/trustgraph/rev_gateway/dispatcher.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
from trustgraph.messaging import TranslatorRegistry
|
||||
from ..gateway.dispatch.manager import DispatcherManager
|
||||
|
||||
logger = logging.getLogger("dispatcher")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class WebSocketResponder:
|
||||
"""Simple responder that captures response for websocket return"""
|
||||
def __init__(self):
|
||||
self.response = None
|
||||
self.completed = False
|
||||
|
||||
async def send(self, data):
|
||||
"""Capture the response data"""
|
||||
self.response = data
|
||||
self.completed = True
|
||||
|
||||
async def __call__(self, data, final=False):
|
||||
"""Make the responder callable for compatibility with requestor"""
|
||||
await self.send(data)
|
||||
if final:
|
||||
self.completed = True
|
||||
|
||||
class MessageDispatcher:
|
||||
|
||||
def __init__(self, max_workers: int = 10, config_receiver=None, pulsar_client=None):
|
||||
self.max_workers = max_workers
|
||||
self.semaphore = asyncio.Semaphore(max_workers)
|
||||
self.active_tasks = set()
|
||||
self.pulsar_client = pulsar_client
|
||||
|
||||
# Use DispatcherManager for flow and service management
|
||||
if pulsar_client and config_receiver:
|
||||
self.dispatcher_manager = DispatcherManager(pulsar_client, config_receiver, prefix="rev-gateway")
|
||||
else:
|
||||
self.dispatcher_manager = None
|
||||
logger.warning("No pulsar_client or config_receiver provided - using fallback mode")
|
||||
|
||||
# Service name mapping from websocket protocol to translator registry
|
||||
self.service_mapping = {
|
||||
"text-completion": "text-completion",
|
||||
"graph-rag": "graph-rag",
|
||||
"agent": "agent",
|
||||
"embeddings": "embeddings",
|
||||
"graph-embeddings": "graph-embeddings",
|
||||
"triples": "triples",
|
||||
"document-load": "document",
|
||||
"text-load": "text-document",
|
||||
"flow": "flow",
|
||||
"knowledge": "knowledge",
|
||||
"config": "config",
|
||||
"librarian": "librarian",
|
||||
"document-rag": "document-rag"
|
||||
}
|
||||
|
||||
async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
|
||||
async with self.semaphore:
|
||||
task = asyncio.create_task(self._process_message(message))
|
||||
self.active_tasks.add(task)
|
||||
|
||||
try:
|
||||
result = await task
|
||||
return result
|
||||
finally:
|
||||
self.active_tasks.discard(task)
|
||||
|
||||
async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
request_id = message.get('id', str(uuid.uuid4()))
|
||||
service = message.get('service')
|
||||
request_data = message.get('request', {})
|
||||
flow_id = message.get('flow', 'default') # Default flow
|
||||
|
||||
logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
|
||||
|
||||
try:
|
||||
if not self.dispatcher_manager:
|
||||
raise RuntimeError("DispatcherManager not available - pulsar_client and config_receiver required")
|
||||
|
||||
# Use DispatcherManager for flow-based processing
|
||||
responder = WebSocketResponder()
|
||||
|
||||
# Map websocket service name to dispatcher service name
|
||||
dispatcher_service = self.service_mapping.get(service, service)
|
||||
|
||||
# Check if this is a global service or flow service
|
||||
from ..gateway.dispatch.manager import global_dispatchers
|
||||
if dispatcher_service in global_dispatchers:
|
||||
# Use global service dispatcher
|
||||
await self.dispatcher_manager.invoke_global_service(
|
||||
request_data, responder, dispatcher_service
|
||||
)
|
||||
else:
|
||||
# Use DispatcherManager to process the request through Pulsar queues
|
||||
await self.dispatcher_manager.invoke_flow_service(
|
||||
request_data, responder, flow_id, dispatcher_service
|
||||
)
|
||||
|
||||
# Get the response from the responder
|
||||
if responder.completed:
|
||||
response_data = responder.response
|
||||
else:
|
||||
response_data = {'error': 'No response received'}
|
||||
|
||||
response = {
|
||||
'id': request_id,
|
||||
'response': response_data
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message {request_id}: {e}")
|
||||
response = {
|
||||
'id': request_id,
|
||||
'response': {'error': str(e)}
|
||||
}
|
||||
|
||||
logger.info(f"Completed processing message {request_id}")
|
||||
return response
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
if self.active_tasks:
|
||||
logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete")
|
||||
await asyncio.gather(*self.active_tasks, return_exceptions=True)
|
||||
|
||||
# DispatcherManager handles its own cleanup
|
||||
logger.info("Dispatcher shutdown complete")
|
||||
240
trustgraph-flow/trustgraph/rev_gateway/service.py
Normal file
240
trustgraph-flow/trustgraph/rev_gateway/service.py
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
import asyncio
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from aiohttp import ClientSession, WSMsgType, ClientWebSocketResponse
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
import pulsar
|
||||
|
||||
from .dispatcher import MessageDispatcher
|
||||
from ..gateway.config.receiver import ConfigReceiver
|
||||
|
||||
logger = logging.getLogger("rev_gateway")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class ReverseGateway:
|
||||
|
||||
def __init__(self, websocket_uri: str = None, max_workers: int = 10,
|
||||
pulsar_host: str = None, pulsar_api_key: str = None,
|
||||
pulsar_listener: str = None):
|
||||
# Set default WebSocket URI with environment variable support
|
||||
if websocket_uri is None:
|
||||
websocket_uri = os.getenv("WEBSOCKET_URI", "wss://api.trustgraph.ai/ws")
|
||||
|
||||
# Parse and validate the WebSocket URI
|
||||
parsed_uri = urlparse(websocket_uri)
|
||||
if parsed_uri.scheme not in ('ws', 'wss'):
|
||||
raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}")
|
||||
if not parsed_uri.netloc:
|
||||
raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}")
|
||||
|
||||
# Store parsed components for debugging/logging
|
||||
self.websocket_uri = websocket_uri
|
||||
self.host = parsed_uri.hostname
|
||||
self.port = parsed_uri.port
|
||||
self.scheme = parsed_uri.scheme
|
||||
self.path = parsed_uri.path or "/ws"
|
||||
|
||||
# Construct the full URL (in case path was missing)
|
||||
if not parsed_uri.path:
|
||||
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
|
||||
else:
|
||||
self.url = websocket_uri
|
||||
|
||||
self.max_workers = max_workers
|
||||
self.ws: Optional[ClientWebSocketResponse] = None
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.running = False
|
||||
self.reconnect_delay = 3.0
|
||||
|
||||
# Pulsar configuration
|
||||
self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
||||
self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
|
||||
self.pulsar_listener = pulsar_listener
|
||||
|
||||
# Initialize Pulsar client
|
||||
if self.pulsar_api_key:
|
||||
self.pulsar_client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
listener_name=self.pulsar_listener,
|
||||
authentication=pulsar.AuthenticationToken(self.pulsar_api_key)
|
||||
)
|
||||
else:
|
||||
self.pulsar_client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
listener_name=self.pulsar_listener
|
||||
)
|
||||
|
||||
# Initialize config receiver
|
||||
self.config_receiver = ConfigReceiver(self.pulsar_client)
|
||||
|
||||
# Initialize dispatcher with config_receiver and pulsar_client - must be created after config_receiver
|
||||
self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.pulsar_client)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
try:
|
||||
if self.session is None:
|
||||
self.session = ClientSession()
|
||||
|
||||
logger.info(f"Connecting to {self.url}")
|
||||
self.ws = await self.session.ws_connect(self.url)
|
||||
logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to {self.url}: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect(self):
|
||||
if self.ws and not self.ws.closed:
|
||||
await self.ws.close()
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
self.ws = None
|
||||
self.session = None
|
||||
|
||||
async def send_message(self, message: dict):
|
||||
if self.ws and not self.ws.closed:
|
||||
try:
|
||||
await self.ws.send_str(json.dumps(message))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message: {e}")
|
||||
|
||||
async def handle_message(self, message: str):
|
||||
try:
|
||||
print(f"Received: {message}", flush=True)
|
||||
|
||||
msg_data = json.loads(message)
|
||||
response = await self.dispatcher.handle_message(msg_data)
|
||||
|
||||
if response:
|
||||
await self.send_message(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
|
||||
async def listen(self):
|
||||
while self.running and self.ws and not self.ws.closed:
|
||||
try:
|
||||
msg = await self.ws.receive()
|
||||
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
await self.handle_message(msg.data)
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
await self.handle_message(msg.data.decode('utf-8'))
|
||||
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
|
||||
logger.warning("WebSocket closed or error occurred")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in listen loop: {e}")
|
||||
break
|
||||
|
||||
async def run(self):
|
||||
self.running = True
|
||||
logger.info("Starting reverse gateway")
|
||||
|
||||
# Start config receiver
|
||||
logger.info("Starting config receiver")
|
||||
await self.config_receiver.start()
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
if await self.connect():
|
||||
await self.listen()
|
||||
else:
|
||||
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")
|
||||
|
||||
await self.disconnect()
|
||||
|
||||
if self.running:
|
||||
await asyncio.sleep(self.reconnect_delay)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutdown requested")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
if self.running:
|
||||
await asyncio.sleep(self.reconnect_delay)
|
||||
|
||||
await self.shutdown()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Shutting down reverse gateway")
|
||||
self.running = False
|
||||
await self.dispatcher.shutdown()
|
||||
await self.disconnect()
|
||||
|
||||
# Close Pulsar client
|
||||
if hasattr(self, 'pulsar_client'):
|
||||
self.pulsar_client.close()
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="reverse-gateway",
|
||||
description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--websocket-uri',
|
||||
default=None,
|
||||
help='WebSocket URI to connect to (default: wss://api.trustgraph.ai/ws or WEBSOCKET_URI env var)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--max-workers',
|
||||
type=int,
|
||||
default=10,
|
||||
help='Maximum concurrent message handlers (default: 10)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--pulsar-host',
|
||||
default=None,
|
||||
help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--pulsar-api-key',
|
||||
default=None,
|
||||
help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--pulsar-listener',
|
||||
default=None,
|
||||
help='Pulsar listener name'
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def run():
|
||||
args = parse_args()
|
||||
|
||||
gateway = ReverseGateway(
|
||||
websocket_uri=args.websocket_uri,
|
||||
max_workers=args.max_workers,
|
||||
pulsar_host=args.pulsar_host,
|
||||
pulsar_api_key=args.pulsar_api_key,
|
||||
pulsar_listener=args.pulsar_listener
|
||||
)
|
||||
|
||||
print(f"Starting reverse gateway:")
|
||||
print(f" WebSocket URI: {gateway.url}")
|
||||
print(f" Max workers: {args.max_workers}")
|
||||
print(f" Pulsar host: {gateway.pulsar_host}")
|
||||
|
||||
try:
|
||||
asyncio.run(gateway.run())
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutdown requested by user")
|
||||
except Exception as e:
|
||||
print(f"Fatal error: {e}")
|
||||
sys.exit(1)
|
||||
Loading…
Add table
Add a link
Reference in a new issue