diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py index c912a460..f04f6054 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py @@ -25,7 +25,7 @@ class SocketEndpoint: await dispatcher.run() async def listener(self, ws, dispatcher, running): - + """Enhanced listener with graceful shutdown""" async for msg in ws: # On error, finish @@ -36,13 +36,16 @@ class SocketEndpoint: await dispatcher.receive(msg) continue else: + # Graceful shutdown on close + logger.info("Websocket closing, initiating graceful shutdown") + running.stop() + + # Allow time for dispatcher cleanup + await asyncio.sleep(1.0) break - - running.stop() - await ws.close() async def handle(self, request): - + """Enhanced handler with better cleanup""" try: token = request.query['token'] except: @@ -55,7 +58,9 @@ class SocketEndpoint: ws = web.WebSocketResponse(max_msg_size=52428800) await ws.prepare(request) - + + dispatcher = None + try: async with asyncio.TaskGroup() as tg: @@ -80,9 +85,6 @@ class SocketEndpoint: logger.debug("Task group closed") - # Finally? - await dispatcher.destroy() - except ExceptionGroup as e: logger.error("Exception group occurred:", exc_info=True) @@ -90,11 +92,34 @@ class SocketEndpoint: for se in e.exceptions: logger.error(f" Exception type: {type(se)}") logger.error(f" Exception: {se}") + + # Attempt graceful dispatcher shutdown + if dispatcher and hasattr(dispatcher, 'destroy'): + try: + await asyncio.wait_for( + dispatcher.destroy(), + timeout=5.0 + ) + except asyncio.TimeoutError: + logger.warning("Dispatcher shutdown timed out") + except Exception as de: + logger.error(f"Error during dispatcher cleanup: {de}") + except Exception as e: logger.error(f"Socket exception: {e}", exc_info=True) - - await ws.close() - + + finally: + # Ensure dispatcher cleanup + if dispatcher and hasattr(dispatcher, 'destroy'): + try: + await dispatcher.destroy() + except Exception as de: + logger.error(f"Error in final dispatcher cleanup: {de}") + + # Ensure websocket is closed + if ws and not ws.closed: + await ws.close() + return ws async def start(self):