diff --git a/trustgraph-flow/trustgraph/gateway/mux.py b/trustgraph-flow/trustgraph/gateway/mux.py index 74797069..ae699ae6 100644 --- a/trustgraph-flow/trustgraph/gateway/mux.py +++ b/trustgraph-flow/trustgraph/gateway/mux.py @@ -8,6 +8,13 @@ from aiohttp import web, WSMsgType from . socket import SocketEndpoint from . text_completion import TextCompletionRequestor +MAX_OUTSTANDING_REQUESTS = 15 +WORKER_CLOSE_WAIT = 0.01 +START_REQUEST_WAIT = 0.1 + +# This buffers requests until task start, so short-lived +MAX_QUEUE_SIZE = 10 + class MuxEndpoint(SocketEndpoint): def __init__( @@ -20,53 +27,113 @@ class MuxEndpoint(SocketEndpoint): endpoint_path=path, auth=auth, ) - self.q = asyncio.Queue(maxsize=10) - self.services = services async def start(self): pass - async def async_thread(self, ws, running): + async def maybe_tidy_workers(self, workers): + + while True: + + try: + + await asyncio.wait_for( + asyncio.shield(workers[0]), + WORKER_CLOSE_WAIT + ) + + # worker[0] now stopped + # FIXME: Delete reference??? + + workers.pop(0) + + if len(workers) == 0: + break + + # Loop iterates to try the next worker + + except TimeoutError: + # worker[0] still running, move on + break + + async def start_request_task(self, ws, id, svc, request, workers): + + if svc not in self.services: + await ws.send_json({"id": id, "error": "Service not recognised"}) + return + + requestor = self.services[svc] + + async def responder(resp, fin): + await ws.send_json({ + "id": id, + "response": resp, + "complete": fin, + }) + + # Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS + while len(workers) > MAX_OUTSTANDING_REQUESTS: + + # Fixes deadlock + # FIXME: Put it in its own loop + await asyncio.sleep(START_REQUEST_WAIT) + + await self.maybe_tidy_workers(workers) + + worker = asyncio.create_task( + requestor.process(request, responder) + ) + + workers.append(worker) + + async def async_thread(self, ws, running, q): + + # Worker threads, servicing + workers = [] while running.get(): try: - id, svc, request = await asyncio.wait_for(self.q.get(), 1) + + if len(workers) > 0: + await self.maybe_tidy_workers(workers) + + # Get next request on queue + id, svc, request = await asyncio.wait_for(q.get(), 1) + except TimeoutError: continue + except Exception as e: + # This is an internal working error, may not be recoverable + print("Exception:", e) await ws.send_json({"id": id, "error": str(e)}) + break try: - - print(svc, request) - - requestor = self.services[svc] - - async def responder(resp, fin): - await ws.send_json({ - "id": id, - "response": resp, - "complete": fin, - }) - - resp = await requestor.process(request, responder) + print(id, svc, request) + await self.start_request_task(ws, id, svc, request, workers) except Exception as e: - + print("Exception2:", e) await ws.send_json({"error": str(e)}) running.stop() async def listener(self, ws, running): + + # The outstanding request queue, max size is MAX_QUEUE_SIZE + q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE) + + async_task = asyncio.create_task(self.async_thread( + ws, running, q + )) async for msg in ws: # On error, finish - if msg.type == WSMsgType.ERROR: - break - else: + if msg.type == WSMsgType.TEXT: try: @@ -81,7 +148,7 @@ class MuxEndpoint(SocketEndpoint): if "id" not in data: raise RuntimeError("Bad message") - await self.q.put( + await q.put( (data["id"], data["service"], data["request"]) ) @@ -90,5 +157,13 @@ class MuxEndpoint(SocketEndpoint): await ws.send_json({"error": str(e)}) continue + elif msg.type == WSMsgType.ERROR: + break + elif msg.type == WSMsgType.CLOSE: + break + else: + break + running.stop() + await async_task diff --git a/trustgraph-flow/trustgraph/gateway/socket.py b/trustgraph-flow/trustgraph/gateway/socket.py index 869792b7..fd408d7b 100644 --- a/trustgraph-flow/trustgraph/gateway/socket.py +++ b/trustgraph-flow/trustgraph/gateway/socket.py @@ -22,25 +22,16 @@ class SocketEndpoint: async for msg in ws: # On error, finish - if msg.type == WSMsgType.ERROR: - break + if msg.type == WSMsgType.TEXT: + # Ignore incoming message + continue + elif msg.type == WSMsgType.BINARY: + # Ignore incoming message + continue else: - # Ignore incoming messages - pass + break running.stop() - - async def async_thread(self, ws, running): - - while running.get(): - try: - await asyncio.sleep(1) - - except TimeoutError: - continue - - except Exception as e: - print(f"Exception: {str(e)}", flush=True) async def handle(self, request): @@ -56,12 +47,8 @@ class SocketEndpoint: ws = web.WebSocketResponse() await ws.prepare(request) - task = asyncio.create_task(self.async_thread(ws, running)) - try: - await self.listener(ws, running) - except Exception as e: print(e, flush=True) @@ -69,8 +56,6 @@ class SocketEndpoint: await ws.close() - await task - return ws async def start(self):