diff --git a/trustgraph-base/trustgraph/base/publisher.py b/trustgraph-base/trustgraph/base/publisher.py index bad7791f..5a481f82 100644 --- a/trustgraph-base/trustgraph/base/publisher.py +++ b/trustgraph-base/trustgraph/base/publisher.py @@ -12,22 +12,27 @@ logger = logging.getLogger(__name__) class Publisher: def __init__(self, client, topic, schema=None, max_size=10, - chunking_enabled=True): + chunking_enabled=True, drain_timeout=5.0): self.client = client self.topic = topic self.schema = schema self.q = asyncio.Queue(maxsize=max_size) self.chunking_enabled = chunking_enabled self.running = True + self.draining = False # New state for graceful shutdown self.task = None + self.drain_timeout = drain_timeout async def start(self): self.task = asyncio.create_task(self.run()) async def stop(self): + """Initiate graceful shutdown with draining""" self.running = False + self.draining = True if self.task: + # Wait for run() to complete draining await self.task async def join(self): @@ -38,7 +43,7 @@ class Publisher: async def run(self): - while self.running: + while self.running or self.draining: try: @@ -48,32 +53,71 @@ class Publisher: chunking_enabled=self.chunking_enabled, ) - while self.running: + drain_end_time = None + + while self.running or self.draining: try: + # Start drain timeout when entering drain mode + if self.draining and drain_end_time is None: + drain_end_time = time.time() + self.drain_timeout + logger.info(f"Publisher entering drain mode, timeout={self.drain_timeout}s") + + # Check drain timeout + if self.draining and drain_end_time and time.time() > drain_end_time: + if not self.q.empty(): + logger.warning(f"Drain timeout reached with {self.q.qsize()} messages remaining") + self.draining = False + break + + # Calculate wait timeout based on mode + if self.draining: + # Shorter timeout during draining to exit quickly when empty + timeout = min(0.1, drain_end_time - time.time()) if drain_end_time else 0.1 + else: + # Normal operation timeout + timeout = 0.25 + id, item = await asyncio.wait_for( self.q.get(), - timeout=0.25 + timeout=timeout ) except asyncio.TimeoutError: + # If draining and queue is empty, we're done + if self.draining and self.q.empty(): + logger.info("Publisher queue drained successfully") + self.draining = False + break continue except asyncio.QueueEmpty: + # If draining and queue is empty, we're done + if self.draining and self.q.empty(): + logger.info("Publisher queue drained successfully") + self.draining = False + break continue if id: producer.send(item, { "id": id }) else: producer.send(item) + + # Flush producer before closing + producer.flush() + producer.close() except Exception as e: logger.error(f"Exception in publisher: {e}", exc_info=True) - if not self.running: + if not self.running and not self.draining: return # If handler drops out, sleep a retry await asyncio.sleep(1) async def send(self, id, item): + if self.draining: + # Optionally reject new messages during drain + raise RuntimeError("Publisher is shutting down, not accepting new messages") await self.q.put((id, item)) diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py index 7b5fa6b5..24b7a45c 100644 --- a/trustgraph-base/trustgraph/base/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -8,6 +8,7 @@ import asyncio import _pulsar import time import logging +import uuid # Module logger logger = logging.getLogger(__name__) @@ -15,7 +16,8 @@ logger = logging.getLogger(__name__) class Subscriber: def __init__(self, client, topic, subscription, consumer_name, - schema=None, max_size=100, metrics=None): + schema=None, max_size=100, metrics=None, + backpressure_strategy="block", drain_timeout=5.0): self.client = client self.topic = topic self.subscription = subscription @@ -26,8 +28,12 @@ class Subscriber: self.max_size = max_size self.lock = asyncio.Lock() self.running = True + self.draining = False # New state for graceful shutdown self.metrics = metrics self.task = None + self.backpressure_strategy = backpressure_strategy + self.drain_timeout = drain_timeout + self.pending_acks = {} # Track messages awaiting delivery self.consumer = None @@ -47,9 +53,12 @@ class Subscriber: self.task = asyncio.create_task(self.run()) async def stop(self): + """Initiate graceful shutdown with draining""" self.running = False + self.draining = True if self.task: + # Wait for run() to complete draining await self.task async def join(self): @@ -59,8 +68,8 @@ class Subscriber: await self.task async def run(self): - - while self.running: + """Enhanced run method with integrated draining logic""" + while self.running or self.draining: if self.metrics: self.metrics.state("stopped") @@ -71,65 +80,73 @@ class Subscriber: self.metrics.state("running") logger.info("Subscriber running...") + drain_end_time = None - while self.running: + while self.running or self.draining: + # Start drain timeout when entering drain mode + if self.draining and drain_end_time is None: + drain_end_time = time.time() + self.drain_timeout + logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s") + + # Stop accepting new messages from Pulsar during drain + if self.consumer: + self.consumer.pause_message_listener() + + # Check drain timeout + if self.draining and drain_end_time and time.time() > drain_end_time: + async with self.lock: + total_pending = sum( + q.qsize() for q in + list(self.q.values()) + list(self.full.values()) + ) + if total_pending > 0: + logger.warning(f"Drain timeout reached with {total_pending} messages in queues") + self.draining = False + break + + # Check if we can exit drain mode + if self.draining: + async with self.lock: + all_empty = all( + q.empty() for q in + list(self.q.values()) + list(self.full.values()) + ) + if all_empty and len(self.pending_acks) == 0: + logger.info("Subscriber queues drained successfully") + self.draining = False + break + + # Process messages only if not draining + if not self.draining: + try: + msg = await asyncio.to_thread( + self.consumer.receive, + timeout_millis=250 + ) + except _pulsar.Timeout: + continue + except Exception as e: + logger.error(f"Exception in subscriber receive: {e}", exc_info=True) + raise e - try: - msg = await asyncio.to_thread( - self.consumer.receive, - timeout_millis=250 - ) - except _pulsar.Timeout: - continue - except Exception as e: - logger.error(f"Exception in subscriber receive: {e}", exc_info=True) - raise e + if self.metrics: + self.metrics.received() - if self.metrics: - self.metrics.received() + # Process the message with deferred acknowledgment + await self._process_message(msg) + else: + # During draining, just wait for queues to empty + await asyncio.sleep(0.1) - # Acknowledge successful reception of the message - self.consumer.acknowledge(msg) - - try: - id = msg.properties()["id"] - except: - id = None - - value = msg.value() - - async with self.lock: - - # FIXME: Hard-coded timeouts - - if id in self.q: - - try: - # FIXME: Timeout means data goes missing - await asyncio.wait_for( - self.q[id].put(value), - timeout=1 - ) - - except Exception as e: - self.metrics.dropped() - logger.warning(f"Failed to put message in queue: {e}") - - for q in self.full.values(): - try: - # FIXME: Timeout means data goes missing - await asyncio.wait_for( - q.put(value), - timeout=1 - ) - except Exception as e: - self.metrics.dropped() - logger.warning(f"Failed to put message in full queue: {e}") except Exception as e: logger.error(f"Subscriber exception: {e}", exc_info=True) finally: + # Negative acknowledge any pending messages + for msg in self.pending_acks.values(): + self.consumer.negative_acknowledge(msg) + self.pending_acks.clear() if self.consumer: self.consumer.unsubscribe() @@ -140,7 +157,7 @@ class Subscriber: if self.metrics: self.metrics.state("stopped") - if not self.running: + if not self.running and not self.draining: return # If handler drops out, sleep a retry @@ -180,3 +197,71 @@ class Subscriber: # self.full[id].shutdown(immediate=True) del self.full[id] + async def _process_message(self, msg): + """Process a single message with deferred acknowledgment""" + # Store message for later acknowledgment + msg_id = str(uuid.uuid4()) + self.pending_acks[msg_id] = msg + + try: + id = msg.properties()["id"] + except: + id = None + + value = msg.value() + delivery_success = False + + async with self.lock: + # Deliver to specific subscribers + if id in self.q: + delivery_success = await self._deliver_to_queue( + self.q[id], value + ) + + # Deliver to all subscribers + for q in self.full.values(): + if await self._deliver_to_queue(q, value): + delivery_success = True + + # Acknowledge only on successful delivery + if delivery_success: + self.consumer.acknowledge(msg) + del self.pending_acks[msg_id] + else: + # Negative acknowledge for retry + self.consumer.negative_acknowledge(msg) + del self.pending_acks[msg_id] + + async def _deliver_to_queue(self, queue, value): + """Deliver message to queue with backpressure handling""" + try: + if self.backpressure_strategy == "block": + # Block until space available (no timeout) + await queue.put(value) + return True + + elif self.backpressure_strategy == "drop_oldest": + # Drop oldest message if queue full + if queue.full(): + try: + queue.get_nowait() + if self.metrics: + self.metrics.dropped() + except asyncio.QueueEmpty: + pass + await queue.put(value) + return True + + elif self.backpressure_strategy == "drop_new": + # Drop new message if queue full + if queue.full(): + if self.metrics: + self.metrics.dropped() + return False + await queue.put(value) + return True + + except Exception as e: + logger.error(f"Failed to deliver message: {e}") + return False + diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py index 1c65e8b3..f7d53005 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py @@ -26,46 +26,66 @@ class DocumentEmbeddingsExport: self.subscriber = subscriber async def destroy(self): + # Step 1: Signal stop to prevent new messages self.running.stop() - await self.ws.close() + + # Step 2: Wait briefly for in-flight messages + await asyncio.sleep(0.5) + + # Step 3: Unsubscribe and stop subscriber (triggers queue drain) + if hasattr(self, 'subs'): + await self.subs.unsubscribe_all(self.id) + await self.subs.stop() + + # Step 4: Close websocket last + if self.ws and not self.ws.closed: + await self.ws.close() async def receive(self, msg): # Ignore incoming info from websocket pass async def run(self): - - subs = Subscriber( - client = self.pulsar_client, topic = self.queue, - consumer_name = self.consumer, subscription = self.subscriber, - schema = DocumentEmbeddings + """Enhanced run with better error handling""" + self.subs = Subscriber( + client = self.pulsar_client, + topic = self.queue, + consumer_name = self.consumer, + subscription = self.subscriber, + schema = DocumentEmbeddings, + backpressure_strategy = "block" # Configurable ) - - await subs.start() - - id = str(uuid.uuid4()) - q = await subs.subscribe_all(id) - + + await self.subs.start() + + self.id = str(uuid.uuid4()) + q = await self.subs.subscribe_all(self.id) + + consecutive_errors = 0 + max_consecutive_errors = 5 + while self.running.get(): try: - resp = await asyncio.wait_for(q.get(), timeout=0.5) await self.ws.send_json(serialize_document_embeddings(resp)) - - except TimeoutError: + consecutive_errors = 0 # Reset on success + + except asyncio.TimeoutError: continue - + except queue.Empty: continue - + except Exception as e: - logger.error(f"Exception: {str(e)}", exc_info=True) - break - - await subs.unsubscribe_all(id) - - await subs.stop() - - await self.ws.close() - self.running.stop() + logger.error(f"Exception sending to websocket: {str(e)}") + consecutive_errors += 1 + + if consecutive_errors >= max_consecutive_errors: + logger.error("Too many consecutive errors, shutting down") + break + + # Brief pause before retry + await asyncio.sleep(0.1) + + # Graceful cleanup handled in destroy() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py index dd4fc4e1..7ec2f595 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py @@ -1,6 +1,7 @@ import asyncio import uuid +import logging from aiohttp import WSMsgType from ... schema import Metadata @@ -8,6 +9,9 @@ from ... schema import DocumentEmbeddings, ChunkEmbeddings from ... base import Publisher from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator +# Module logger +logger = logging.getLogger(__name__) + class DocumentEmbeddingsImport: def __init__( @@ -26,13 +30,17 @@ class DocumentEmbeddingsImport: await self.publisher.start() async def destroy(self): + # Step 1: Stop accepting new messages self.running.stop() + # Step 2: Wait for publisher to drain its queue + logger.info("Draining publisher queue...") + await self.publisher.stop() + + # Step 3: Close websocket only after queue is drained if self.ws: await self.ws.close() - await self.publisher.stop() - async def receive(self, msg): data = msg.json() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py index 9585c1d0..2be9c703 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py @@ -26,46 +26,66 @@ class EntityContextsExport: self.subscriber = subscriber async def destroy(self): + # Step 1: Signal stop to prevent new messages self.running.stop() - await self.ws.close() + + # Step 2: Wait briefly for in-flight messages + await asyncio.sleep(0.5) + + # Step 3: Unsubscribe and stop subscriber (triggers queue drain) + if hasattr(self, 'subs'): + await self.subs.unsubscribe_all(self.id) + await self.subs.stop() + + # Step 4: Close websocket last + if self.ws and not self.ws.closed: + await self.ws.close() async def receive(self, msg): # Ignore incoming info from websocket pass async def run(self): - - subs = Subscriber( - client = self.pulsar_client, topic = self.queue, - consumer_name = self.consumer, subscription = self.subscriber, - schema = EntityContexts + """Enhanced run with better error handling""" + self.subs = Subscriber( + client = self.pulsar_client, + topic = self.queue, + consumer_name = self.consumer, + subscription = self.subscriber, + schema = EntityContexts, + backpressure_strategy = "block" # Configurable ) - - await subs.start() - - id = str(uuid.uuid4()) - q = await subs.subscribe_all(id) - + + await self.subs.start() + + self.id = str(uuid.uuid4()) + q = await self.subs.subscribe_all(self.id) + + consecutive_errors = 0 + max_consecutive_errors = 5 + while self.running.get(): try: - resp = await asyncio.wait_for(q.get(), timeout=0.5) await self.ws.send_json(serialize_entity_contexts(resp)) - - except TimeoutError: + consecutive_errors = 0 # Reset on success + + except asyncio.TimeoutError: continue - + except queue.Empty: continue - + except Exception as e: - logger.error(f"Exception: {str(e)}", exc_info=True) - break - - await subs.unsubscribe_all(id) - - await subs.stop() - - await self.ws.close() - self.running.stop() + logger.error(f"Exception sending to websocket: {str(e)}") + consecutive_errors += 1 + + if consecutive_errors >= max_consecutive_errors: + logger.error("Too many consecutive errors, shutting down") + break + + # Brief pause before retry + await asyncio.sleep(0.1) + + # Graceful cleanup handled in destroy() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py index 22d18904..c76f1612 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py @@ -1,6 +1,7 @@ import asyncio import uuid +import logging from aiohttp import WSMsgType from ... schema import Metadata @@ -9,6 +10,9 @@ from ... base import Publisher from . serialize import to_subgraph, to_value +# Module logger +logger = logging.getLogger(__name__) + class EntityContextsImport: def __init__( @@ -26,13 +30,17 @@ class EntityContextsImport: await self.publisher.start() async def destroy(self): + # Step 1: Stop accepting new messages self.running.stop() + # Step 2: Wait for publisher to drain its queue + logger.info("Draining publisher queue...") + await self.publisher.stop() + + # Step 3: Close websocket only after queue is drained if self.ws: await self.ws.close() - await self.publisher.stop() - async def receive(self, msg): data = msg.json() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_export.py index 44c70dfd..d4abec73 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_export.py @@ -26,46 +26,66 @@ class GraphEmbeddingsExport: self.subscriber = subscriber async def destroy(self): + # Step 1: Signal stop to prevent new messages self.running.stop() - await self.ws.close() + + # Step 2: Wait briefly for in-flight messages + await asyncio.sleep(0.5) + + # Step 3: Unsubscribe and stop subscriber (triggers queue drain) + if hasattr(self, 'subs'): + await self.subs.unsubscribe_all(self.id) + await self.subs.stop() + + # Step 4: Close websocket last + if self.ws and not self.ws.closed: + await self.ws.close() async def receive(self, msg): # Ignore incoming info from websocket pass async def run(self): - - subs = Subscriber( - client = self.pulsar_client, topic = self.queue, - consumer_name = self.consumer, subscription = self.subscriber, - schema = GraphEmbeddings + """Enhanced run with better error handling""" + self.subs = Subscriber( + client = self.pulsar_client, + topic = self.queue, + consumer_name = self.consumer, + subscription = self.subscriber, + schema = GraphEmbeddings, + backpressure_strategy = "block" # Configurable ) - - await subs.start() - - id = str(uuid.uuid4()) - q = await subs.subscribe_all(id) - + + await self.subs.start() + + self.id = str(uuid.uuid4()) + q = await self.subs.subscribe_all(self.id) + + consecutive_errors = 0 + max_consecutive_errors = 5 + while self.running.get(): try: - resp = await asyncio.wait_for(q.get(), timeout=0.5) await self.ws.send_json(serialize_graph_embeddings(resp)) - - except TimeoutError: + consecutive_errors = 0 # Reset on success + + except asyncio.TimeoutError: continue - + except queue.Empty: continue - + except Exception as e: - logger.error(f"Exception: {str(e)}", exc_info=True) - break - - await subs.unsubscribe_all(id) - - await subs.stop() - - await self.ws.close() - self.running.stop() + logger.error(f"Exception sending to websocket: {str(e)}") + consecutive_errors += 1 + + if consecutive_errors >= max_consecutive_errors: + logger.error("Too many consecutive errors, shutting down") + break + + # Brief pause before retry + await asyncio.sleep(0.1) + + # Graceful cleanup handled in destroy() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py index 85174460..ee3d88ef 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py @@ -1,6 +1,7 @@ import asyncio import uuid +import logging from aiohttp import WSMsgType from ... schema import Metadata @@ -9,6 +10,9 @@ from ... base import Publisher from . serialize import to_subgraph, to_value +# Module logger +logger = logging.getLogger(__name__) + class GraphEmbeddingsImport: def __init__( @@ -26,13 +30,17 @@ class GraphEmbeddingsImport: await self.publisher.start() async def destroy(self): + # Step 1: Stop accepting new messages self.running.stop() + # Step 2: Wait for publisher to drain its queue + logger.info("Draining publisher queue...") + await self.publisher.stop() + + # Step 3: Close websocket only after queue is drained if self.ws: await self.ws.close() - await self.publisher.stop() - async def receive(self, msg): data = msg.json() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_export.py index 2847c182..ff91e461 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_export.py @@ -26,46 +26,66 @@ class TriplesExport: self.subscriber = subscriber async def destroy(self): + # Step 1: Signal stop to prevent new messages self.running.stop() - await self.ws.close() + + # Step 2: Wait briefly for in-flight messages + await asyncio.sleep(0.5) + + # Step 3: Unsubscribe and stop subscriber (triggers queue drain) + if hasattr(self, 'subs'): + await self.subs.unsubscribe_all(self.id) + await self.subs.stop() + + # Step 4: Close websocket last + if self.ws and not self.ws.closed: + await self.ws.close() async def receive(self, msg): # Ignore incoming info from websocket pass async def run(self): - - subs = Subscriber( - client = self.pulsar_client, topic = self.queue, - consumer_name = self.consumer, subscription = self.subscriber, - schema = Triples + """Enhanced run with better error handling""" + self.subs = Subscriber( + client = self.pulsar_client, + topic = self.queue, + consumer_name = self.consumer, + subscription = self.subscriber, + schema = Triples, + backpressure_strategy = "block" # Configurable ) - - await subs.start() - - id = str(uuid.uuid4()) - q = await subs.subscribe_all(id) - + + await self.subs.start() + + self.id = str(uuid.uuid4()) + q = await self.subs.subscribe_all(self.id) + + consecutive_errors = 0 + max_consecutive_errors = 5 + while self.running.get(): try: - resp = await asyncio.wait_for(q.get(), timeout=0.5) await self.ws.send_json(serialize_triples(resp)) - - except TimeoutError: + consecutive_errors = 0 # Reset on success + + except asyncio.TimeoutError: continue - + except queue.Empty: continue - + except Exception as e: - logger.error(f"Exception: {str(e)}", exc_info=True) - break - - await subs.unsubscribe_all(id) - - await subs.stop() - - await self.ws.close() - self.running.stop() + logger.error(f"Exception sending to websocket: {str(e)}") + consecutive_errors += 1 + + if consecutive_errors >= max_consecutive_errors: + logger.error("Too many consecutive errors, shutting down") + break + + # Brief pause before retry + await asyncio.sleep(0.1) + + # Graceful cleanup handled in destroy() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py index 687b424a..520a9cbc 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py @@ -1,6 +1,7 @@ import asyncio import uuid +import logging from aiohttp import WSMsgType from ... schema import Metadata @@ -9,6 +10,9 @@ from ... base import Publisher from . serialize import to_subgraph +# Module logger +logger = logging.getLogger(__name__) + class TriplesImport: def __init__( @@ -26,13 +30,17 @@ class TriplesImport: await self.publisher.start() async def destroy(self): + # Step 1: Stop accepting new messages self.running.stop() + # Step 2: Wait for publisher to drain its queue + logger.info("Draining publisher queue...") + await self.publisher.stop() + + # Step 3: Close websocket only after queue is drained if self.ws: await self.ws.close() - await self.publisher.stop() - async def receive(self, msg): data = msg.json()