Fix import export graceful shutdown (#476)

* Tech spec for graceful shutdown

* Graceful shutdown of importers/exporters

* Update socket to include graceful shutdown orchestration

* Adding tests for conditions tracked in this PR
This commit is contained in:
cybermaggedon 2025-08-28 13:39:28 +01:00 committed by GitHub
parent 4361e8ccca
commit 96c2b73457
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 2668 additions and 193 deletions

View file

@@ -26,46 +26,66 @@ class DocumentEmbeddingsExport:
self.subscriber = subscriber
async def destroy(self):
# Step 1: Signal stop to prevent new messages
self.running.stop()
await self.ws.close()
# Step 2: Wait briefly for in-flight messages
await asyncio.sleep(0.5)
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
if hasattr(self, 'subs'):
await self.subs.unsubscribe_all(self.id)
await self.subs.stop()
# Step 4: Close websocket last
if self.ws and not self.ws.closed:
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = DocumentEmbeddings
"""Enhanced run with better error handling"""
self.subs = Subscriber(
client = self.pulsar_client,
topic = self.queue,
consumer_name = self.consumer,
subscription = self.subscriber,
schema = DocumentEmbeddings,
backpressure_strategy = "block" # Configurable
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
await self.subs.start()
self.id = str(uuid.uuid4())
q = await self.subs.subscribe_all(self.id)
consecutive_errors = 0
max_consecutive_errors = 5
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_document_embeddings(resp))
except TimeoutError:
consecutive_errors = 0 # Reset on success
except asyncio.TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
logger.error(f"Exception: {str(e)}", exc_info=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()
logger.error(f"Exception sending to websocket: {str(e)}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.error("Too many consecutive errors, shutting down")
break
# Brief pause before retry
await asyncio.sleep(0.1)
# Graceful cleanup handled in destroy()

View file

@@ -1,6 +1,7 @@
import asyncio
import uuid
import logging
from aiohttp import WSMsgType
from ... schema import Metadata
@@ -8,6 +9,9 @@ from ... schema import DocumentEmbeddings, ChunkEmbeddings
from ... base import Publisher
from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator
# Module logger
logger = logging.getLogger(__name__)
class DocumentEmbeddingsImport:
def __init__(
@@ -26,13 +30,17 @@ class DocumentEmbeddingsImport:
await self.publisher.start()
async def destroy(self):
# Step 1: Stop accepting new messages
self.running.stop()
# Step 2: Wait for publisher to drain its queue
logger.info("Draining publisher queue...")
await self.publisher.stop()
# Step 3: Close websocket only after queue is drained
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()

View file

@@ -26,46 +26,66 @@ class EntityContextsExport:
self.subscriber = subscriber
async def destroy(self):
# Step 1: Signal stop to prevent new messages
self.running.stop()
await self.ws.close()
# Step 2: Wait briefly for in-flight messages
await asyncio.sleep(0.5)
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
if hasattr(self, 'subs'):
await self.subs.unsubscribe_all(self.id)
await self.subs.stop()
# Step 4: Close websocket last
if self.ws and not self.ws.closed:
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = EntityContexts
"""Enhanced run with better error handling"""
self.subs = Subscriber(
client = self.pulsar_client,
topic = self.queue,
consumer_name = self.consumer,
subscription = self.subscriber,
schema = EntityContexts,
backpressure_strategy = "block" # Configurable
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
await self.subs.start()
self.id = str(uuid.uuid4())
q = await self.subs.subscribe_all(self.id)
consecutive_errors = 0
max_consecutive_errors = 5
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_entity_contexts(resp))
except TimeoutError:
consecutive_errors = 0 # Reset on success
except asyncio.TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
logger.error(f"Exception: {str(e)}", exc_info=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()
logger.error(f"Exception sending to websocket: {str(e)}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.error("Too many consecutive errors, shutting down")
break
# Brief pause before retry
await asyncio.sleep(0.1)
# Graceful cleanup handled in destroy()

View file

@@ -1,6 +1,7 @@
import asyncio
import uuid
import logging
from aiohttp import WSMsgType
from ... schema import Metadata
@@ -9,6 +10,9 @@ from ... base import Publisher
from . serialize import to_subgraph, to_value
# Module logger
logger = logging.getLogger(__name__)
class EntityContextsImport:
def __init__(
@@ -26,13 +30,17 @@ class EntityContextsImport:
await self.publisher.start()
async def destroy(self):
# Step 1: Stop accepting new messages
self.running.stop()
# Step 2: Wait for publisher to drain its queue
logger.info("Draining publisher queue...")
await self.publisher.stop()
# Step 3: Close websocket only after queue is drained
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()

View file

@@ -26,46 +26,66 @@ class GraphEmbeddingsExport:
self.subscriber = subscriber
async def destroy(self):
# Step 1: Signal stop to prevent new messages
self.running.stop()
await self.ws.close()
# Step 2: Wait briefly for in-flight messages
await asyncio.sleep(0.5)
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
if hasattr(self, 'subs'):
await self.subs.unsubscribe_all(self.id)
await self.subs.stop()
# Step 4: Close websocket last
if self.ws and not self.ws.closed:
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = GraphEmbeddings
"""Enhanced run with better error handling"""
self.subs = Subscriber(
client = self.pulsar_client,
topic = self.queue,
consumer_name = self.consumer,
subscription = self.subscriber,
schema = GraphEmbeddings,
backpressure_strategy = "block" # Configurable
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
await self.subs.start()
self.id = str(uuid.uuid4())
q = await self.subs.subscribe_all(self.id)
consecutive_errors = 0
max_consecutive_errors = 5
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError:
consecutive_errors = 0 # Reset on success
except asyncio.TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
logger.error(f"Exception: {str(e)}", exc_info=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()
logger.error(f"Exception sending to websocket: {str(e)}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.error("Too many consecutive errors, shutting down")
break
# Brief pause before retry
await asyncio.sleep(0.1)
# Graceful cleanup handled in destroy()

View file

@@ -1,6 +1,7 @@
import asyncio
import uuid
import logging
from aiohttp import WSMsgType
from ... schema import Metadata
@@ -9,6 +10,9 @@ from ... base import Publisher
from . serialize import to_subgraph, to_value
# Module logger
logger = logging.getLogger(__name__)
class GraphEmbeddingsImport:
def __init__(
@@ -26,13 +30,17 @@ class GraphEmbeddingsImport:
await self.publisher.start()
async def destroy(self):
# Step 1: Stop accepting new messages
self.running.stop()
# Step 2: Wait for publisher to drain its queue
logger.info("Draining publisher queue...")
await self.publisher.stop()
# Step 3: Close websocket only after queue is drained
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()

View file

@@ -26,46 +26,66 @@ class TriplesExport:
self.subscriber = subscriber
async def destroy(self):
# Step 1: Signal stop to prevent new messages
self.running.stop()
await self.ws.close()
# Step 2: Wait briefly for in-flight messages
await asyncio.sleep(0.5)
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
if hasattr(self, 'subs'):
await self.subs.unsubscribe_all(self.id)
await self.subs.stop()
# Step 4: Close websocket last
if self.ws and not self.ws.closed:
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = Triples
"""Enhanced run with better error handling"""
self.subs = Subscriber(
client = self.pulsar_client,
topic = self.queue,
consumer_name = self.consumer,
subscription = self.subscriber,
schema = Triples,
backpressure_strategy = "block" # Configurable
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
await self.subs.start()
self.id = str(uuid.uuid4())
q = await self.subs.subscribe_all(self.id)
consecutive_errors = 0
max_consecutive_errors = 5
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_triples(resp))
except TimeoutError:
consecutive_errors = 0 # Reset on success
except asyncio.TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
logger.error(f"Exception: {str(e)}", exc_info=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()
logger.error(f"Exception sending to websocket: {str(e)}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.error("Too many consecutive errors, shutting down")
break
# Brief pause before retry
await asyncio.sleep(0.1)
# Graceful cleanup handled in destroy()

View file

@@ -1,6 +1,7 @@
import asyncio
import uuid
import logging
from aiohttp import WSMsgType
from ... schema import Metadata
@@ -9,6 +10,9 @@ from ... base import Publisher
from . serialize import to_subgraph
# Module logger
logger = logging.getLogger(__name__)
class TriplesImport:
def __init__(
@@ -26,13 +30,17 @@ class TriplesImport:
await self.publisher.start()
async def destroy(self):
# Step 1: Stop accepting new messages
self.running.stop()
# Step 2: Wait for publisher to drain its queue
logger.info("Draining publisher queue...")
await self.publisher.stop()
# Step 3: Close websocket only after queue is drained
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()

View file

@@ -25,24 +25,43 @@ class SocketEndpoint:
await dispatcher.run()
async def listener(self, ws, dispatcher, running):
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.TEXT:
await dispatcher.receive(msg)
continue
elif msg.type == WSMsgType.BINARY:
await dispatcher.receive(msg)
continue
"""Enhanced listener with graceful shutdown"""
try:
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.TEXT:
await dispatcher.receive(msg)
continue
elif msg.type == WSMsgType.BINARY:
await dispatcher.receive(msg)
continue
else:
# Graceful shutdown on close
logger.info("Websocket closing, initiating graceful shutdown")
running.stop()
# Allow time for dispatcher cleanup
await asyncio.sleep(1.0)
# Close websocket if not already closed
if not ws.closed:
await ws.close()
break
else:
break
running.stop()
await ws.close()
# This executes when the async for loop completes normally (no break)
logger.debug("Websocket iteration completed, performing cleanup")
running.stop()
if not ws.closed:
await ws.close()
except Exception:
# Handle exceptions and cleanup
running.stop()
if not ws.closed:
await ws.close()
raise
async def handle(self, request):
"""Enhanced handler with better cleanup"""
try:
token = request.query['token']
except:
@@ -55,7 +74,9 @@ class SocketEndpoint:
ws = web.WebSocketResponse(max_msg_size=52428800)
await ws.prepare(request)
dispatcher = None
try:
async with asyncio.TaskGroup() as tg:
@@ -80,9 +101,6 @@ class SocketEndpoint:
logger.debug("Task group closed")
# Finally?
await dispatcher.destroy()
except ExceptionGroup as e:
logger.error("Exception group occurred:", exc_info=True)
@@ -90,11 +108,34 @@ class SocketEndpoint:
for se in e.exceptions:
logger.error(f" Exception type: {type(se)}")
logger.error(f" Exception: {se}")
# Attempt graceful dispatcher shutdown
if dispatcher and hasattr(dispatcher, 'destroy'):
try:
await asyncio.wait_for(
dispatcher.destroy(),
timeout=5.0
)
except asyncio.TimeoutError:
logger.warning("Dispatcher shutdown timed out")
except Exception as de:
logger.error(f"Error during dispatcher cleanup: {de}")
except Exception as e:
logger.error(f"Socket exception: {e}", exc_info=True)
await ws.close()
finally:
# Ensure dispatcher cleanup
if dispatcher and hasattr(dispatcher, 'destroy'):
try:
await dispatcher.destroy()
except Exception as de:
logger.error(f"Error in final dispatcher cleanup: {de}")
# Ensure websocket is closed
if ws and not ws.closed:
await ws.close()
return ws
async def start(self):