Mirror of https://github.com/trustgraph-ai/trustgraph.git, synced 2026-05-10 07:42:38 +02:00
Fix import export graceful shutdown (#476)
* Tech spec for graceful shutdown
* Graceful shutdown of importers/exporters
* Update socket to include graceful shutdown orchestration
* Adding tests for conditions tracked in this PR
parent 4361e8ccca
commit 96c2b73457
17 changed files with 2668 additions and 193 deletions
@@ -26,46 +26,66 @@ class DocumentEmbeddingsExport:
         self.subscriber = subscriber

     async def destroy(self):
+        # Step 1: Signal stop to prevent new messages
         self.running.stop()
-        await self.ws.close()
+
+        # Step 2: Wait briefly for in-flight messages
+        await asyncio.sleep(0.5)
+
+        # Step 3: Unsubscribe and stop subscriber (triggers queue drain)
+        if hasattr(self, 'subs'):
+            await self.subs.unsubscribe_all(self.id)
+            await self.subs.stop()
+
+        # Step 4: Close websocket last
+        if self.ws and not self.ws.closed:
+            await self.ws.close()

     async def receive(self, msg):
         # Ignore incoming info from websocket
         pass

     async def run(self):
-
-        subs = Subscriber(
-            client = self.pulsar_client, topic = self.queue,
-            consumer_name = self.consumer, subscription = self.subscriber,
-            schema = DocumentEmbeddings
+        """Enhanced run with better error handling"""
+        self.subs = Subscriber(
+            client = self.pulsar_client,
+            topic = self.queue,
+            consumer_name = self.consumer,
+            subscription = self.subscriber,
+            schema = DocumentEmbeddings,
+            backpressure_strategy = "block"  # Configurable
         )

-        await subs.start()
-
-        id = str(uuid.uuid4())
-        q = await subs.subscribe_all(id)
+        await self.subs.start()
+
+        self.id = str(uuid.uuid4())
+        q = await self.subs.subscribe_all(self.id)
+
+        consecutive_errors = 0
+        max_consecutive_errors = 5

         while self.running.get():
             try:
                 resp = await asyncio.wait_for(q.get(), timeout=0.5)
                 await self.ws.send_json(serialize_document_embeddings(resp))
+                consecutive_errors = 0  # Reset on success

-            except TimeoutError:
+            except asyncio.TimeoutError:
                 continue

             except queue.Empty:
                 continue

             except Exception as e:
-                logger.error(f"Exception: {str(e)}", exc_info=True)
-                break
-
-        await subs.unsubscribe_all(id)
-
-        await subs.stop()
-
-        await self.ws.close()
-        self.running.stop()
+                logger.error(f"Exception sending to websocket: {str(e)}")
+                consecutive_errors += 1
+
+                if consecutive_errors >= max_consecutive_errors:
+                    logger.error("Too many consecutive errors, shutting down")
+                    break
+
+                # Brief pause before retry
+                await asyncio.sleep(0.1)
+
+        # Graceful cleanup handled in destroy()
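The destroy() rewrite above fixes the ordering of the old code, which closed the websocket before the subscriber had a chance to drain. Below is a minimal sketch of the intended ordering, using hypothetical Running and FakeSubscriber stand-ins rather than the repo's own classes:

import asyncio

class Running:
    """Hypothetical stop flag, standing in for the running object used in the diff."""
    def __init__(self):
        self._on = True
    def get(self):
        return self._on
    def stop(self):
        self._on = False

class FakeSubscriber:
    """Stand-in subscriber: stop() drains whatever is still queued."""
    def __init__(self):
        self.queue = asyncio.Queue()
    async def stop(self):
        while not self.queue.empty():
            print("draining:", self.queue.get_nowait())

async def close_ws():
    # Stand-in for aiohttp's ws.close()
    print("websocket closed last")

async def exporter_shutdown(running, subscriber):
    # Step 1: stop the run() loop from pulling new messages
    running.stop()
    # Step 2: give any in-flight send a moment to finish
    await asyncio.sleep(0.5)
    # Step 3: drain and stop the subscriber while the socket is still open
    await subscriber.stop()
    # Step 4: only now close the websocket
    await close_ws()

async def main():
    running, subscriber = Running(), FakeSubscriber()
    await subscriber.queue.put("pending message")
    await exporter_shutdown(running, subscriber)

asyncio.run(main())

The point of the ordering is that the websocket stays open until the subscriber queue is empty, so messages already pulled from Pulsar are not dropped during shutdown.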
@@ -1,6 +1,7 @@

 import asyncio
 import uuid
+import logging
 from aiohttp import WSMsgType

 from ... schema import Metadata

@@ -8,6 +9,9 @@ from ... schema import DocumentEmbeddings, ChunkEmbeddings
 from ... base import Publisher
 from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator

+# Module logger
+logger = logging.getLogger(__name__)
+
 class DocumentEmbeddingsImport:

     def __init__(

@@ -26,13 +30,17 @@ class DocumentEmbeddingsImport:
         await self.publisher.start()

     async def destroy(self):
+        # Step 1: Stop accepting new messages
         self.running.stop()
-        await self.ws.close()

-        await self.publisher.stop()
+        # Step 2: Wait for publisher to drain its queue
+        logger.info("Draining publisher queue...")
+        await self.publisher.stop()
+
+        # Step 3: Close websocket only after queue is drained
+        if self.ws:
+            await self.ws.close()

     async def receive(self, msg):

         data = msg.json()
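On the import side, the new ordering depends on Publisher.stop() draining its queue before the websocket is closed. The Publisher internals are not part of this diff, so the sketch below only illustrates one way a drain-on-stop publisher can be built (DrainingPublisher is a hypothetical name, not the repo's Publisher):

import asyncio

class DrainingPublisher:
    """Hypothetical publisher whose stop() waits for queued messages to flush."""

    def __init__(self):
        self.queue = asyncio.Queue()
        self._task = None

    async def start(self):
        self._task = asyncio.create_task(self._worker())

    async def _worker(self):
        while True:
            msg = await self.queue.get()
            print("published:", msg)   # stand-in for the real Pulsar send
            self.queue.task_done()

    async def send(self, msg):
        await self.queue.put(msg)

    async def stop(self):
        # Block until every queued message has been handed to the worker
        await self.queue.join()
        self._task.cancel()
        await asyncio.gather(self._task, return_exceptions=True)

async def main():
    pub = DrainingPublisher()
    await pub.start()
    for i in range(3):
        await pub.send(f"message {i}")
    await pub.stop()                   # returns only after the drain completes
    print("safe to close the websocket now")

asyncio.run(main())

With that behaviour, the new destroy() order (stop flag, publisher.stop(), then ws.close()) avoids losing messages that were already read off the websocket.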
@@ -26,46 +26,66 @@ class EntityContextsExport:
         self.subscriber = subscriber

     async def destroy(self):
+        # Step 1: Signal stop to prevent new messages
         self.running.stop()
-        await self.ws.close()
+
+        # Step 2: Wait briefly for in-flight messages
+        await asyncio.sleep(0.5)
+
+        # Step 3: Unsubscribe and stop subscriber (triggers queue drain)
+        if hasattr(self, 'subs'):
+            await self.subs.unsubscribe_all(self.id)
+            await self.subs.stop()
+
+        # Step 4: Close websocket last
+        if self.ws and not self.ws.closed:
+            await self.ws.close()

     async def receive(self, msg):
         # Ignore incoming info from websocket
         pass

     async def run(self):
-
-        subs = Subscriber(
-            client = self.pulsar_client, topic = self.queue,
-            consumer_name = self.consumer, subscription = self.subscriber,
-            schema = EntityContexts
+        """Enhanced run with better error handling"""
+        self.subs = Subscriber(
+            client = self.pulsar_client,
+            topic = self.queue,
+            consumer_name = self.consumer,
+            subscription = self.subscriber,
+            schema = EntityContexts,
+            backpressure_strategy = "block"  # Configurable
         )

-        await subs.start()
-
-        id = str(uuid.uuid4())
-        q = await subs.subscribe_all(id)
+        await self.subs.start()
+
+        self.id = str(uuid.uuid4())
+        q = await self.subs.subscribe_all(self.id)
+
+        consecutive_errors = 0
+        max_consecutive_errors = 5

         while self.running.get():
             try:
                 resp = await asyncio.wait_for(q.get(), timeout=0.5)
                 await self.ws.send_json(serialize_entity_contexts(resp))
+                consecutive_errors = 0  # Reset on success

-            except TimeoutError:
+            except asyncio.TimeoutError:
                 continue

             except queue.Empty:
                 continue

             except Exception as e:
-                logger.error(f"Exception: {str(e)}", exc_info=True)
-                break
-
-        await subs.unsubscribe_all(id)
-
-        await subs.stop()
-
-        await self.ws.close()
-        self.running.stop()
+                logger.error(f"Exception sending to websocket: {str(e)}")
+                consecutive_errors += 1
+
+                if consecutive_errors >= max_consecutive_errors:
+                    logger.error("Too many consecutive errors, shutting down")
+                    break
+
+                # Brief pause before retry
+                await asyncio.sleep(0.1)
+
+        # Graceful cleanup handled in destroy()
@@ -1,6 +1,7 @@

 import asyncio
 import uuid
+import logging
 from aiohttp import WSMsgType

 from ... schema import Metadata

@@ -9,6 +10,9 @@ from ... base import Publisher

 from . serialize import to_subgraph, to_value

+# Module logger
+logger = logging.getLogger(__name__)
+
 class EntityContextsImport:

     def __init__(

@@ -26,13 +30,17 @@ class EntityContextsImport:
         await self.publisher.start()

     async def destroy(self):
+        # Step 1: Stop accepting new messages
         self.running.stop()
-        await self.ws.close()

-        await self.publisher.stop()
+        # Step 2: Wait for publisher to drain its queue
+        logger.info("Draining publisher queue...")
+        await self.publisher.stop()
+
+        # Step 3: Close websocket only after queue is drained
+        if self.ws:
+            await self.ws.close()

     async def receive(self, msg):

         data = msg.json()
@@ -26,46 +26,66 @@ class GraphEmbeddingsExport:
         self.subscriber = subscriber

     async def destroy(self):
+        # Step 1: Signal stop to prevent new messages
         self.running.stop()
-        await self.ws.close()
+
+        # Step 2: Wait briefly for in-flight messages
+        await asyncio.sleep(0.5)
+
+        # Step 3: Unsubscribe and stop subscriber (triggers queue drain)
+        if hasattr(self, 'subs'):
+            await self.subs.unsubscribe_all(self.id)
+            await self.subs.stop()
+
+        # Step 4: Close websocket last
+        if self.ws and not self.ws.closed:
+            await self.ws.close()

     async def receive(self, msg):
         # Ignore incoming info from websocket
         pass

     async def run(self):
-
-        subs = Subscriber(
-            client = self.pulsar_client, topic = self.queue,
-            consumer_name = self.consumer, subscription = self.subscriber,
-            schema = GraphEmbeddings
+        """Enhanced run with better error handling"""
+        self.subs = Subscriber(
+            client = self.pulsar_client,
+            topic = self.queue,
+            consumer_name = self.consumer,
+            subscription = self.subscriber,
+            schema = GraphEmbeddings,
+            backpressure_strategy = "block"  # Configurable
         )

-        await subs.start()
-
-        id = str(uuid.uuid4())
-        q = await subs.subscribe_all(id)
+        await self.subs.start()
+
+        self.id = str(uuid.uuid4())
+        q = await self.subs.subscribe_all(self.id)
+
+        consecutive_errors = 0
+        max_consecutive_errors = 5

         while self.running.get():
             try:
                 resp = await asyncio.wait_for(q.get(), timeout=0.5)
                 await self.ws.send_json(serialize_graph_embeddings(resp))
+                consecutive_errors = 0  # Reset on success

-            except TimeoutError:
+            except asyncio.TimeoutError:
                 continue

             except queue.Empty:
                 continue

             except Exception as e:
-                logger.error(f"Exception: {str(e)}", exc_info=True)
-                break
-
-        await subs.unsubscribe_all(id)
-
-        await subs.stop()
-
-        await self.ws.close()
-        self.running.stop()
+                logger.error(f"Exception sending to websocket: {str(e)}")
+                consecutive_errors += 1
+
+                if consecutive_errors >= max_consecutive_errors:
+                    logger.error("Too many consecutive errors, shutting down")
+                    break
+
+                # Brief pause before retry
+                await asyncio.sleep(0.1)
+
+        # Graceful cleanup handled in destroy()
@@ -1,6 +1,7 @@

 import asyncio
 import uuid
+import logging
 from aiohttp import WSMsgType

 from ... schema import Metadata

@@ -9,6 +10,9 @@ from ... base import Publisher

 from . serialize import to_subgraph, to_value

+# Module logger
+logger = logging.getLogger(__name__)
+
 class GraphEmbeddingsImport:

     def __init__(

@@ -26,13 +30,17 @@ class GraphEmbeddingsImport:
         await self.publisher.start()

     async def destroy(self):
+        # Step 1: Stop accepting new messages
         self.running.stop()
-        await self.ws.close()

-        await self.publisher.stop()
+        # Step 2: Wait for publisher to drain its queue
+        logger.info("Draining publisher queue...")
+        await self.publisher.stop()
+
+        # Step 3: Close websocket only after queue is drained
+        if self.ws:
+            await self.ws.close()

     async def receive(self, msg):

         data = msg.json()
@@ -26,46 +26,66 @@ class TriplesExport:
         self.subscriber = subscriber

     async def destroy(self):
+        # Step 1: Signal stop to prevent new messages
         self.running.stop()
-        await self.ws.close()
+
+        # Step 2: Wait briefly for in-flight messages
+        await asyncio.sleep(0.5)
+
+        # Step 3: Unsubscribe and stop subscriber (triggers queue drain)
+        if hasattr(self, 'subs'):
+            await self.subs.unsubscribe_all(self.id)
+            await self.subs.stop()
+
+        # Step 4: Close websocket last
+        if self.ws and not self.ws.closed:
+            await self.ws.close()

     async def receive(self, msg):
         # Ignore incoming info from websocket
         pass

     async def run(self):
-
-        subs = Subscriber(
-            client = self.pulsar_client, topic = self.queue,
-            consumer_name = self.consumer, subscription = self.subscriber,
-            schema = Triples
+        """Enhanced run with better error handling"""
+        self.subs = Subscriber(
+            client = self.pulsar_client,
+            topic = self.queue,
+            consumer_name = self.consumer,
+            subscription = self.subscriber,
+            schema = Triples,
+            backpressure_strategy = "block"  # Configurable
         )

-        await subs.start()
-
-        id = str(uuid.uuid4())
-        q = await subs.subscribe_all(id)
+        await self.subs.start()
+
+        self.id = str(uuid.uuid4())
+        q = await self.subs.subscribe_all(self.id)
+
+        consecutive_errors = 0
+        max_consecutive_errors = 5

         while self.running.get():
             try:
                 resp = await asyncio.wait_for(q.get(), timeout=0.5)
                 await self.ws.send_json(serialize_triples(resp))
+                consecutive_errors = 0  # Reset on success

-            except TimeoutError:
+            except asyncio.TimeoutError:
                 continue

             except queue.Empty:
                 continue

             except Exception as e:
-                logger.error(f"Exception: {str(e)}", exc_info=True)
-                break
-
-        await subs.unsubscribe_all(id)
-
-        await subs.stop()
-
-        await self.ws.close()
-        self.running.stop()
+                logger.error(f"Exception sending to websocket: {str(e)}")
+                consecutive_errors += 1
+
+                if consecutive_errors >= max_consecutive_errors:
+                    logger.error("Too many consecutive errors, shutting down")
+                    break
+
+                # Brief pause before retry
+                await asyncio.sleep(0.1)
+
+        # Graceful cleanup handled in destroy()
@@ -1,6 +1,7 @@

 import asyncio
 import uuid
+import logging
 from aiohttp import WSMsgType

 from ... schema import Metadata

@@ -9,6 +10,9 @@ from ... base import Publisher

 from . serialize import to_subgraph

+# Module logger
+logger = logging.getLogger(__name__)
+
 class TriplesImport:

     def __init__(

@@ -26,13 +30,17 @@ class TriplesImport:
         await self.publisher.start()

     async def destroy(self):
+        # Step 1: Stop accepting new messages
         self.running.stop()
-        await self.ws.close()

-        await self.publisher.stop()
+        # Step 2: Wait for publisher to drain its queue
+        logger.info("Draining publisher queue...")
+        await self.publisher.stop()
+
+        # Step 3: Close websocket only after queue is drained
+        if self.ws:
+            await self.ws.close()

     async def receive(self, msg):

         data = msg.json()
@@ -25,24 +25,43 @@ class SocketEndpoint:
         await dispatcher.run()

     async def listener(self, ws, dispatcher, running):
-
-        async for msg in ws:
-
-            # On error, finish
-            if msg.type == WSMsgType.TEXT:
-                await dispatcher.receive(msg)
-                continue
-            elif msg.type == WSMsgType.BINARY:
-                await dispatcher.receive(msg)
-                continue
-            else:
-                break
-
-        running.stop()
-        await ws.close()
+        """Enhanced listener with graceful shutdown"""
+        try:
+            async for msg in ws:
+                # On error, finish
+                if msg.type == WSMsgType.TEXT:
+                    await dispatcher.receive(msg)
+                    continue
+                elif msg.type == WSMsgType.BINARY:
+                    await dispatcher.receive(msg)
+                    continue
+                else:
+                    # Graceful shutdown on close
+                    logger.info("Websocket closing, initiating graceful shutdown")
+                    running.stop()
+
+                    # Allow time for dispatcher cleanup
+                    await asyncio.sleep(1.0)
+
+                    # Close websocket if not already closed
+                    if not ws.closed:
+                        await ws.close()
+                    break
+
+            # This executes when the async for loop completes normally (no break)
+            logger.debug("Websocket iteration completed, performing cleanup")
+            running.stop()
+            if not ws.closed:
+                await ws.close()
+        except Exception:
+            # Handle exceptions and cleanup
+            running.stop()
+            if not ws.closed:
+                await ws.close()
+            raise

     async def handle(self, request):
-
+        """Enhanced handler with better cleanup"""
         try:
             token = request.query['token']
         except:

@@ -55,7 +74,9 @@ class SocketEndpoint:
         ws = web.WebSocketResponse(max_msg_size=52428800)

         await ws.prepare(request)

+        dispatcher = None
+
         try:

            async with asyncio.TaskGroup() as tg:

@@ -80,9 +101,6 @@ class SocketEndpoint:

             logger.debug("Task group closed")

-            # Finally?
-            await dispatcher.destroy()
-
         except ExceptionGroup as e:

             logger.error("Exception group occurred:", exc_info=True)

@@ -90,11 +108,34 @@ class SocketEndpoint:
             for se in e.exceptions:
                 logger.error(f" Exception type: {type(se)}")
                 logger.error(f" Exception: {se}")

+            # Attempt graceful dispatcher shutdown
+            if dispatcher and hasattr(dispatcher, 'destroy'):
+                try:
+                    await asyncio.wait_for(
+                        dispatcher.destroy(),
+                        timeout=5.0
+                    )
+                except asyncio.TimeoutError:
+                    logger.warning("Dispatcher shutdown timed out")
+                except Exception as de:
+                    logger.error(f"Error during dispatcher cleanup: {de}")
+
         except Exception as e:
             logger.error(f"Socket exception: {e}", exc_info=True)

-        await ws.close()
+        finally:
+            # Ensure dispatcher cleanup
+            if dispatcher and hasattr(dispatcher, 'destroy'):
+                try:
+                    await dispatcher.destroy()
+                except Exception as de:
+                    logger.error(f"Error in final dispatcher cleanup: {de}")
+
+            # Ensure websocket is closed
+            if ws and not ws.closed:
+                await ws.close()

         return ws

     async def start(self):
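The handler above wraps dispatcher.destroy() in asyncio.wait_for so a stuck cleanup cannot hang the connection teardown. A small self-contained sketch of that pattern, assuming a generic async cleanup step (slow_destroy and close_with_timeout are illustrative names, not the gateway's API):

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

async def slow_destroy():
    # Stand-in for a dispatcher.destroy() that never finishes in time
    await asyncio.sleep(10)

async def close_with_timeout(cleanup, timeout=5.0):
    """Run an async cleanup step, but never wait longer than `timeout` seconds."""
    try:
        await asyncio.wait_for(cleanup(), timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning("Cleanup timed out after %.1fs, continuing shutdown", timeout)
    except Exception:
        logger.exception("Error during cleanup")

async def main():
    await close_with_timeout(slow_destroy, timeout=0.1)
    print("connection teardown continues even though cleanup hung")

asyncio.run(main())

The finally block in the diff still calls destroy() unconditionally as a backstop, so cleanup runs even when the timed attempt in the error branch was skipped or timed out.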