Fix import export graceful shutdown (#476)

* Tech spec for graceful shutdown

* Graceful shutdown of importers/exporters

* Update socket to include graceful shutdown orchestration

* Adding tests for conditions tracked in this PR
This commit is contained in:
cybermaggedon 2025-08-28 13:39:28 +01:00 committed by GitHub
parent 4361e8ccca
commit 96c2b73457
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 2668 additions and 193 deletions

View file

@ -12,22 +12,27 @@ logger = logging.getLogger(__name__)
class Publisher:
def __init__(self, client, topic, schema=None, max_size=10,
chunking_enabled=True):
chunking_enabled=True, drain_timeout=5.0):
self.client = client
self.topic = topic
self.schema = schema
self.q = asyncio.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled
self.running = True
self.draining = False # New state for graceful shutdown
self.task = None
self.drain_timeout = drain_timeout
async def start(self):
self.task = asyncio.create_task(self.run())
async def stop(self):
"""Initiate graceful shutdown with draining"""
self.running = False
self.draining = True
if self.task:
# Wait for run() to complete draining
await self.task
async def join(self):
@ -38,7 +43,7 @@ class Publisher:
async def run(self):
while self.running:
while self.running or self.draining:
try:
@ -48,32 +53,71 @@ class Publisher:
chunking_enabled=self.chunking_enabled,
)
while self.running:
drain_end_time = None
while self.running or self.draining:
try:
# Start drain timeout when entering drain mode
if self.draining and drain_end_time is None:
drain_end_time = time.time() + self.drain_timeout
logger.info(f"Publisher entering drain mode, timeout={self.drain_timeout}s")
# Check drain timeout
if self.draining and drain_end_time and time.time() > drain_end_time:
if not self.q.empty():
logger.warning(f"Drain timeout reached with {self.q.qsize()} messages remaining")
self.draining = False
break
# Calculate wait timeout based on mode
if self.draining:
# Shorter timeout during draining to exit quickly when empty
timeout = min(0.1, drain_end_time - time.time()) if drain_end_time else 0.1
else:
# Normal operation timeout
timeout = 0.25
id, item = await asyncio.wait_for(
self.q.get(),
timeout=0.25
timeout=timeout
)
except asyncio.TimeoutError:
# If draining and queue is empty, we're done
if self.draining and self.q.empty():
logger.info("Publisher queue drained successfully")
self.draining = False
break
continue
except asyncio.QueueEmpty:
# If draining and queue is empty, we're done
if self.draining and self.q.empty():
logger.info("Publisher queue drained successfully")
self.draining = False
break
continue
if id:
producer.send(item, { "id": id })
else:
producer.send(item)
# Flush producer before closing
producer.flush()
producer.close()
except Exception as e:
logger.error(f"Exception in publisher: {e}", exc_info=True)
if not self.running:
if not self.running and not self.draining:
return
# If handler drops out, sleep a retry
await asyncio.sleep(1)
async def send(self, id, item):
if self.draining:
# Optionally reject new messages during drain
raise RuntimeError("Publisher is shutting down, not accepting new messages")
await self.q.put((id, item))

View file

@ -8,6 +8,7 @@ import asyncio
import _pulsar
import time
import logging
import uuid
# Module logger
logger = logging.getLogger(__name__)
@ -15,7 +16,8 @@ logger = logging.getLogger(__name__)
class Subscriber:
def __init__(self, client, topic, subscription, consumer_name,
schema=None, max_size=100, metrics=None):
schema=None, max_size=100, metrics=None,
backpressure_strategy="block", drain_timeout=5.0):
self.client = client
self.topic = topic
self.subscription = subscription
@ -26,8 +28,12 @@ class Subscriber:
self.max_size = max_size
self.lock = asyncio.Lock()
self.running = True
self.draining = False # New state for graceful shutdown
self.metrics = metrics
self.task = None
self.backpressure_strategy = backpressure_strategy
self.drain_timeout = drain_timeout
self.pending_acks = {} # Track messages awaiting delivery
self.consumer = None
@ -47,9 +53,12 @@ class Subscriber:
self.task = asyncio.create_task(self.run())
async def stop(self):
"""Initiate graceful shutdown with draining"""
self.running = False
self.draining = True
if self.task:
# Wait for run() to complete draining
await self.task
async def join(self):
@ -59,8 +68,8 @@ class Subscriber:
await self.task
async def run(self):
while self.running:
"""Enhanced run method with integrated draining logic"""
while self.running or self.draining:
if self.metrics:
self.metrics.state("stopped")
@ -71,65 +80,73 @@ class Subscriber:
self.metrics.state("running")
logger.info("Subscriber running...")
drain_end_time = None
while self.running:
while self.running or self.draining:
# Start drain timeout when entering drain mode
if self.draining and drain_end_time is None:
drain_end_time = time.time() + self.drain_timeout
logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s")
# Stop accepting new messages from Pulsar during drain
if self.consumer:
self.consumer.pause_message_listener()
# Check drain timeout
if self.draining and drain_end_time and time.time() > drain_end_time:
async with self.lock:
total_pending = sum(
q.qsize() for q in
list(self.q.values()) + list(self.full.values())
)
if total_pending > 0:
logger.warning(f"Drain timeout reached with {total_pending} messages in queues")
self.draining = False
break
# Check if we can exit drain mode
if self.draining:
async with self.lock:
all_empty = all(
q.empty() for q in
list(self.q.values()) + list(self.full.values())
)
if all_empty and len(self.pending_acks) == 0:
logger.info("Subscriber queues drained successfully")
self.draining = False
break
# Process messages only if not draining
if not self.draining:
try:
msg = await asyncio.to_thread(
self.consumer.receive,
timeout_millis=250
)
except _pulsar.Timeout:
continue
except Exception as e:
logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
raise e
try:
msg = await asyncio.to_thread(
self.consumer.receive,
timeout_millis=250
)
except _pulsar.Timeout:
continue
except Exception as e:
logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
raise e
if self.metrics:
self.metrics.received()
if self.metrics:
self.metrics.received()
# Process the message with deferred acknowledgment
await self._process_message(msg)
else:
# During draining, just wait for queues to empty
await asyncio.sleep(0.1)
# Acknowledge successful reception of the message
self.consumer.acknowledge(msg)
try:
id = msg.properties()["id"]
except:
id = None
value = msg.value()
async with self.lock:
# FIXME: Hard-coded timeouts
if id in self.q:
try:
# FIXME: Timeout means data goes missing
await asyncio.wait_for(
self.q[id].put(value),
timeout=1
)
except Exception as e:
self.metrics.dropped()
logger.warning(f"Failed to put message in queue: {e}")
for q in self.full.values():
try:
# FIXME: Timeout means data goes missing
await asyncio.wait_for(
q.put(value),
timeout=1
)
except Exception as e:
self.metrics.dropped()
logger.warning(f"Failed to put message in full queue: {e}")
except Exception as e:
logger.error(f"Subscriber exception: {e}", exc_info=True)
finally:
# Negative acknowledge any pending messages
for msg in self.pending_acks.values():
self.consumer.negative_acknowledge(msg)
self.pending_acks.clear()
if self.consumer:
self.consumer.unsubscribe()
@ -140,7 +157,7 @@ class Subscriber:
if self.metrics:
self.metrics.state("stopped")
if not self.running:
if not self.running and not self.draining:
return
# If handler drops out, sleep a retry
@ -180,3 +197,71 @@ class Subscriber:
# self.full[id].shutdown(immediate=True)
del self.full[id]
async def _process_message(self, msg):
    """Process a single Pulsar message with deferred acknowledgment.

    The message is tracked in ``self.pending_acks`` while delivery is
    attempted, so the ``run()`` cleanup path can negatively acknowledge
    anything still outstanding on shutdown.  On successful delivery the
    message is acknowledged; otherwise it is negatively acknowledged so
    the broker will redeliver it.
    """
    # Track the in-flight message for the shutdown nack sweep.
    msg_id = str(uuid.uuid4())
    self.pending_acks[msg_id] = msg

    # The "id" property routes the message to one specific subscriber
    # queue; messages without it go only to the catch-all queues.
    # NOTE: the original used a bare `except:`, which in async code also
    # swallows CancelledError — narrowed to the absent-property cases.
    try:
        id = msg.properties()["id"]
    except (KeyError, TypeError):
        id = None

    value = msg.value()
    delivery_success = False

    async with self.lock:
        # Deliver to the subscriber registered for this id, if any.
        if id in self.q:
            delivery_success = await self._deliver_to_queue(
                self.q[id], value
            )
        # Deliver to every catch-all subscriber; any one success is
        # enough to consider the message delivered.
        for q in self.full.values():
            if await self._deliver_to_queue(q, value):
                delivery_success = True

    # Ack only when at least one queue took the message; otherwise
    # nack so Pulsar redelivers it later.
    if delivery_success:
        self.consumer.acknowledge(msg)
    else:
        self.consumer.negative_acknowledge(msg)
    del self.pending_acks[msg_id]
async def _deliver_to_queue(self, queue, value):
    """Deliver *value* to *queue* honouring the backpressure strategy.

    Strategies:
      - "block": wait indefinitely for queue space (no data loss).
      - "drop_oldest": evict the oldest queued message to make room.
      - "drop_new": discard the incoming message when the queue is full.

    Returns True when the value was enqueued, False when it was dropped
    or delivery failed.
    """
    try:
        if self.backpressure_strategy == "block":
            # Wait for space; nothing is ever dropped.
            await queue.put(value)
            return True
        elif self.backpressure_strategy == "drop_oldest":
            # Make room by discarding the oldest queued message.
            if queue.full():
                try:
                    queue.get_nowait()
                    if self.metrics:
                        self.metrics.dropped()
                except asyncio.QueueEmpty:
                    # Raced with a consumer; queue has room now.
                    pass
            await queue.put(value)
            return True
        elif self.backpressure_strategy == "drop_new":
            # Discard the incoming message when there is no room.
            if queue.full():
                if self.metrics:
                    self.metrics.dropped()
                return False
            await queue.put(value)
            return True
        else:
            # Previously an unrecognised strategy fell through the
            # if/elif chain and implicitly returned None; make that
            # failure explicit and visible in the logs.
            logger.error(
                f"Unknown backpressure strategy: "
                f"{self.backpressure_strategy}"
            )
            return False
    except Exception as e:
        # Delivery is best-effort at this boundary; report and let the
        # caller nack the message for redelivery.
        logger.error(f"Failed to deliver message: {e}")
        return False