mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 15:01:00 +02:00
Graceful shutdown of importers/exporters
This commit is contained in:
parent
eef4c808e7
commit
976529de69
10 changed files with 412 additions and 171 deletions
|
|
@ -12,22 +12,27 @@ logger = logging.getLogger(__name__)
|
||||||
class Publisher:
|
class Publisher:
|
||||||
|
|
||||||
def __init__(self, client, topic, schema=None, max_size=10,
|
def __init__(self, client, topic, schema=None, max_size=10,
|
||||||
chunking_enabled=True):
|
chunking_enabled=True, drain_timeout=5.0):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.topic = topic
|
self.topic = topic
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.q = asyncio.Queue(maxsize=max_size)
|
self.q = asyncio.Queue(maxsize=max_size)
|
||||||
self.chunking_enabled = chunking_enabled
|
self.chunking_enabled = chunking_enabled
|
||||||
self.running = True
|
self.running = True
|
||||||
|
self.draining = False # New state for graceful shutdown
|
||||||
self.task = None
|
self.task = None
|
||||||
|
self.drain_timeout = drain_timeout
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.task = asyncio.create_task(self.run())
|
self.task = asyncio.create_task(self.run())
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
|
"""Initiate graceful shutdown with draining"""
|
||||||
self.running = False
|
self.running = False
|
||||||
|
self.draining = True
|
||||||
|
|
||||||
if self.task:
|
if self.task:
|
||||||
|
# Wait for run() to complete draining
|
||||||
await self.task
|
await self.task
|
||||||
|
|
||||||
async def join(self):
|
async def join(self):
|
||||||
|
|
@ -38,7 +43,7 @@ class Publisher:
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
|
||||||
while self.running:
|
while self.running or self.draining:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
|
@ -48,16 +53,48 @@ class Publisher:
|
||||||
chunking_enabled=self.chunking_enabled,
|
chunking_enabled=self.chunking_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
while self.running:
|
drain_end_time = None
|
||||||
|
|
||||||
|
while self.running or self.draining:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Start drain timeout when entering drain mode
|
||||||
|
if self.draining and drain_end_time is None:
|
||||||
|
drain_end_time = time.time() + self.drain_timeout
|
||||||
|
logger.info(f"Publisher entering drain mode, timeout={self.drain_timeout}s")
|
||||||
|
|
||||||
|
# Check drain timeout
|
||||||
|
if self.draining and drain_end_time and time.time() > drain_end_time:
|
||||||
|
if not self.q.empty():
|
||||||
|
logger.warning(f"Drain timeout reached with {self.q.qsize()} messages remaining")
|
||||||
|
self.draining = False
|
||||||
|
break
|
||||||
|
|
||||||
|
# Calculate wait timeout based on mode
|
||||||
|
if self.draining:
|
||||||
|
# Shorter timeout during draining to exit quickly when empty
|
||||||
|
timeout = min(0.1, drain_end_time - time.time()) if drain_end_time else 0.1
|
||||||
|
else:
|
||||||
|
# Normal operation timeout
|
||||||
|
timeout = 0.25
|
||||||
|
|
||||||
id, item = await asyncio.wait_for(
|
id, item = await asyncio.wait_for(
|
||||||
self.q.get(),
|
self.q.get(),
|
||||||
timeout=0.25
|
timeout=timeout
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
# If draining and queue is empty, we're done
|
||||||
|
if self.draining and self.q.empty():
|
||||||
|
logger.info("Publisher queue drained successfully")
|
||||||
|
self.draining = False
|
||||||
|
break
|
||||||
continue
|
continue
|
||||||
except asyncio.QueueEmpty:
|
except asyncio.QueueEmpty:
|
||||||
|
# If draining and queue is empty, we're done
|
||||||
|
if self.draining and self.q.empty():
|
||||||
|
logger.info("Publisher queue drained successfully")
|
||||||
|
self.draining = False
|
||||||
|
break
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if id:
|
if id:
|
||||||
|
|
@ -65,15 +102,22 @@ class Publisher:
|
||||||
else:
|
else:
|
||||||
producer.send(item)
|
producer.send(item)
|
||||||
|
|
||||||
|
# Flush producer before closing
|
||||||
|
producer.flush()
|
||||||
|
producer.close()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception in publisher: {e}", exc_info=True)
|
logger.error(f"Exception in publisher: {e}", exc_info=True)
|
||||||
|
|
||||||
if not self.running:
|
if not self.running and not self.draining:
|
||||||
return
|
return
|
||||||
|
|
||||||
# If handler drops out, sleep a retry
|
# If handler drops out, sleep a retry
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def send(self, id, item):
|
async def send(self, id, item):
|
||||||
|
if self.draining:
|
||||||
|
# Optionally reject new messages during drain
|
||||||
|
raise RuntimeError("Publisher is shutting down, not accepting new messages")
|
||||||
await self.q.put((id, item))
|
await self.q.put((id, item))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import asyncio
|
||||||
import _pulsar
|
import _pulsar
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -15,7 +16,8 @@ logger = logging.getLogger(__name__)
|
||||||
class Subscriber:
|
class Subscriber:
|
||||||
|
|
||||||
def __init__(self, client, topic, subscription, consumer_name,
|
def __init__(self, client, topic, subscription, consumer_name,
|
||||||
schema=None, max_size=100, metrics=None):
|
schema=None, max_size=100, metrics=None,
|
||||||
|
backpressure_strategy="block", drain_timeout=5.0):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.topic = topic
|
self.topic = topic
|
||||||
self.subscription = subscription
|
self.subscription = subscription
|
||||||
|
|
@ -26,8 +28,12 @@ class Subscriber:
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
self.running = True
|
self.running = True
|
||||||
|
self.draining = False # New state for graceful shutdown
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
self.task = None
|
self.task = None
|
||||||
|
self.backpressure_strategy = backpressure_strategy
|
||||||
|
self.drain_timeout = drain_timeout
|
||||||
|
self.pending_acks = {} # Track messages awaiting delivery
|
||||||
|
|
||||||
self.consumer = None
|
self.consumer = None
|
||||||
|
|
||||||
|
|
@ -47,9 +53,12 @@ class Subscriber:
|
||||||
self.task = asyncio.create_task(self.run())
|
self.task = asyncio.create_task(self.run())
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
|
"""Initiate graceful shutdown with draining"""
|
||||||
self.running = False
|
self.running = False
|
||||||
|
self.draining = True
|
||||||
|
|
||||||
if self.task:
|
if self.task:
|
||||||
|
# Wait for run() to complete draining
|
||||||
await self.task
|
await self.task
|
||||||
|
|
||||||
async def join(self):
|
async def join(self):
|
||||||
|
|
@ -59,8 +68,8 @@ class Subscriber:
|
||||||
await self.task
|
await self.task
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
"""Enhanced run method with integrated draining logic"""
|
||||||
while self.running:
|
while self.running or self.draining:
|
||||||
|
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.state("stopped")
|
self.metrics.state("stopped")
|
||||||
|
|
@ -71,65 +80,73 @@ class Subscriber:
|
||||||
self.metrics.state("running")
|
self.metrics.state("running")
|
||||||
|
|
||||||
logger.info("Subscriber running...")
|
logger.info("Subscriber running...")
|
||||||
|
drain_end_time = None
|
||||||
|
|
||||||
while self.running:
|
while self.running or self.draining:
|
||||||
|
# Start drain timeout when entering drain mode
|
||||||
|
if self.draining and drain_end_time is None:
|
||||||
|
drain_end_time = time.time() + self.drain_timeout
|
||||||
|
logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s")
|
||||||
|
|
||||||
try:
|
# Stop accepting new messages from Pulsar during drain
|
||||||
msg = await asyncio.to_thread(
|
if self.consumer:
|
||||||
self.consumer.receive,
|
self.consumer.pause_message_listener()
|
||||||
timeout_millis=250
|
|
||||||
)
|
|
||||||
except _pulsar.Timeout:
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
if self.metrics:
|
# Check drain timeout
|
||||||
self.metrics.received()
|
if self.draining and drain_end_time and time.time() > drain_end_time:
|
||||||
|
async with self.lock:
|
||||||
|
total_pending = sum(
|
||||||
|
q.qsize() for q in
|
||||||
|
list(self.q.values()) + list(self.full.values())
|
||||||
|
)
|
||||||
|
if total_pending > 0:
|
||||||
|
logger.warning(f"Drain timeout reached with {total_pending} messages in queues")
|
||||||
|
self.draining = False
|
||||||
|
break
|
||||||
|
|
||||||
# Acknowledge successful reception of the message
|
# Check if we can exit drain mode
|
||||||
self.consumer.acknowledge(msg)
|
if self.draining:
|
||||||
|
async with self.lock:
|
||||||
|
all_empty = all(
|
||||||
|
q.empty() for q in
|
||||||
|
list(self.q.values()) + list(self.full.values())
|
||||||
|
)
|
||||||
|
if all_empty and len(self.pending_acks) == 0:
|
||||||
|
logger.info("Subscriber queues drained successfully")
|
||||||
|
self.draining = False
|
||||||
|
break
|
||||||
|
|
||||||
try:
|
# Process messages only if not draining
|
||||||
id = msg.properties()["id"]
|
if not self.draining:
|
||||||
except:
|
try:
|
||||||
id = None
|
msg = await asyncio.to_thread(
|
||||||
|
self.consumer.receive,
|
||||||
|
timeout_millis=250
|
||||||
|
)
|
||||||
|
except _pulsar.Timeout:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
|
||||||
|
raise e
|
||||||
|
|
||||||
value = msg.value()
|
if self.metrics:
|
||||||
|
self.metrics.received()
|
||||||
|
|
||||||
async with self.lock:
|
# Process the message with deferred acknowledgment
|
||||||
|
await self._process_message(msg)
|
||||||
|
else:
|
||||||
|
# During draining, just wait for queues to empty
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
# FIXME: Hard-coded timeouts
|
|
||||||
|
|
||||||
if id in self.q:
|
|
||||||
|
|
||||||
try:
|
|
||||||
# FIXME: Timeout means data goes missing
|
|
||||||
await asyncio.wait_for(
|
|
||||||
self.q[id].put(value),
|
|
||||||
timeout=1
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.metrics.dropped()
|
|
||||||
logger.warning(f"Failed to put message in queue: {e}")
|
|
||||||
|
|
||||||
for q in self.full.values():
|
|
||||||
try:
|
|
||||||
# FIXME: Timeout means data goes missing
|
|
||||||
await asyncio.wait_for(
|
|
||||||
q.put(value),
|
|
||||||
timeout=1
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.metrics.dropped()
|
|
||||||
logger.warning(f"Failed to put message in full queue: {e}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Subscriber exception: {e}", exc_info=True)
|
logger.error(f"Subscriber exception: {e}", exc_info=True)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
# Negative acknowledge any pending messages
|
||||||
|
for msg in self.pending_acks.values():
|
||||||
|
self.consumer.negative_acknowledge(msg)
|
||||||
|
self.pending_acks.clear()
|
||||||
|
|
||||||
if self.consumer:
|
if self.consumer:
|
||||||
self.consumer.unsubscribe()
|
self.consumer.unsubscribe()
|
||||||
|
|
@ -140,7 +157,7 @@ class Subscriber:
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.state("stopped")
|
self.metrics.state("stopped")
|
||||||
|
|
||||||
if not self.running:
|
if not self.running and not self.draining:
|
||||||
return
|
return
|
||||||
|
|
||||||
# If handler drops out, sleep a retry
|
# If handler drops out, sleep a retry
|
||||||
|
|
@ -180,3 +197,71 @@ class Subscriber:
|
||||||
# self.full[id].shutdown(immediate=True)
|
# self.full[id].shutdown(immediate=True)
|
||||||
del self.full[id]
|
del self.full[id]
|
||||||
|
|
||||||
|
async def _process_message(self, msg):
|
||||||
|
"""Process a single message with deferred acknowledgment"""
|
||||||
|
# Store message for later acknowledgment
|
||||||
|
msg_id = str(uuid.uuid4())
|
||||||
|
self.pending_acks[msg_id] = msg
|
||||||
|
|
||||||
|
try:
|
||||||
|
id = msg.properties()["id"]
|
||||||
|
except:
|
||||||
|
id = None
|
||||||
|
|
||||||
|
value = msg.value()
|
||||||
|
delivery_success = False
|
||||||
|
|
||||||
|
async with self.lock:
|
||||||
|
# Deliver to specific subscribers
|
||||||
|
if id in self.q:
|
||||||
|
delivery_success = await self._deliver_to_queue(
|
||||||
|
self.q[id], value
|
||||||
|
)
|
||||||
|
|
||||||
|
# Deliver to all subscribers
|
||||||
|
for q in self.full.values():
|
||||||
|
if await self._deliver_to_queue(q, value):
|
||||||
|
delivery_success = True
|
||||||
|
|
||||||
|
# Acknowledge only on successful delivery
|
||||||
|
if delivery_success:
|
||||||
|
self.consumer.acknowledge(msg)
|
||||||
|
del self.pending_acks[msg_id]
|
||||||
|
else:
|
||||||
|
# Negative acknowledge for retry
|
||||||
|
self.consumer.negative_acknowledge(msg)
|
||||||
|
del self.pending_acks[msg_id]
|
||||||
|
|
||||||
|
async def _deliver_to_queue(self, queue, value):
|
||||||
|
"""Deliver message to queue with backpressure handling"""
|
||||||
|
try:
|
||||||
|
if self.backpressure_strategy == "block":
|
||||||
|
# Block until space available (no timeout)
|
||||||
|
await queue.put(value)
|
||||||
|
return True
|
||||||
|
|
||||||
|
elif self.backpressure_strategy == "drop_oldest":
|
||||||
|
# Drop oldest message if queue full
|
||||||
|
if queue.full():
|
||||||
|
try:
|
||||||
|
queue.get_nowait()
|
||||||
|
if self.metrics:
|
||||||
|
self.metrics.dropped()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
pass
|
||||||
|
await queue.put(value)
|
||||||
|
return True
|
||||||
|
|
||||||
|
elif self.backpressure_strategy == "drop_new":
|
||||||
|
# Drop new message if queue full
|
||||||
|
if queue.full():
|
||||||
|
if self.metrics:
|
||||||
|
self.metrics.dropped()
|
||||||
|
return False
|
||||||
|
await queue.put(value)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to deliver message: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,46 +26,66 @@ class DocumentEmbeddingsExport:
|
||||||
self.subscriber = subscriber
|
self.subscriber = subscriber
|
||||||
|
|
||||||
async def destroy(self):
|
async def destroy(self):
|
||||||
|
# Step 1: Signal stop to prevent new messages
|
||||||
self.running.stop()
|
self.running.stop()
|
||||||
await self.ws.close()
|
|
||||||
|
# Step 2: Wait briefly for in-flight messages
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
|
||||||
|
if hasattr(self, 'subs'):
|
||||||
|
await self.subs.unsubscribe_all(self.id)
|
||||||
|
await self.subs.stop()
|
||||||
|
|
||||||
|
# Step 4: Close websocket last
|
||||||
|
if self.ws and not self.ws.closed:
|
||||||
|
await self.ws.close()
|
||||||
|
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
# Ignore incoming info from websocket
|
# Ignore incoming info from websocket
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
"""Enhanced run with better error handling"""
|
||||||
subs = Subscriber(
|
self.subs = Subscriber(
|
||||||
client = self.pulsar_client, topic = self.queue,
|
client = self.pulsar_client,
|
||||||
consumer_name = self.consumer, subscription = self.subscriber,
|
topic = self.queue,
|
||||||
schema = DocumentEmbeddings
|
consumer_name = self.consumer,
|
||||||
|
subscription = self.subscriber,
|
||||||
|
schema = DocumentEmbeddings,
|
||||||
|
backpressure_strategy = "block" # Configurable
|
||||||
)
|
)
|
||||||
|
|
||||||
await subs.start()
|
await self.subs.start()
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
self.id = str(uuid.uuid4())
|
||||||
q = await subs.subscribe_all(id)
|
q = await self.subs.subscribe_all(self.id)
|
||||||
|
|
||||||
|
consecutive_errors = 0
|
||||||
|
max_consecutive_errors = 5
|
||||||
|
|
||||||
while self.running.get():
|
while self.running.get():
|
||||||
try:
|
try:
|
||||||
|
|
||||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||||
await self.ws.send_json(serialize_document_embeddings(resp))
|
await self.ws.send_json(serialize_document_embeddings(resp))
|
||||||
|
consecutive_errors = 0 # Reset on success
|
||||||
|
|
||||||
except TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception: {str(e)}", exc_info=True)
|
logger.error(f"Exception sending to websocket: {str(e)}")
|
||||||
break
|
consecutive_errors += 1
|
||||||
|
|
||||||
await subs.unsubscribe_all(id)
|
if consecutive_errors >= max_consecutive_errors:
|
||||||
|
logger.error("Too many consecutive errors, shutting down")
|
||||||
|
break
|
||||||
|
|
||||||
await subs.stop()
|
# Brief pause before retry
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
await self.ws.close()
|
# Graceful cleanup handled in destroy()
|
||||||
self.running.stop()
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
|
import logging
|
||||||
from aiohttp import WSMsgType
|
from aiohttp import WSMsgType
|
||||||
|
|
||||||
from ... schema import Metadata
|
from ... schema import Metadata
|
||||||
|
|
@ -8,6 +9,9 @@ from ... schema import DocumentEmbeddings, ChunkEmbeddings
|
||||||
from ... base import Publisher
|
from ... base import Publisher
|
||||||
from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator
|
from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator
|
||||||
|
|
||||||
|
# Module logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DocumentEmbeddingsImport:
|
class DocumentEmbeddingsImport:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -26,13 +30,17 @@ class DocumentEmbeddingsImport:
|
||||||
await self.publisher.start()
|
await self.publisher.start()
|
||||||
|
|
||||||
async def destroy(self):
|
async def destroy(self):
|
||||||
|
# Step 1: Stop accepting new messages
|
||||||
self.running.stop()
|
self.running.stop()
|
||||||
|
|
||||||
|
# Step 2: Wait for publisher to drain its queue
|
||||||
|
logger.info("Draining publisher queue...")
|
||||||
|
await self.publisher.stop()
|
||||||
|
|
||||||
|
# Step 3: Close websocket only after queue is drained
|
||||||
if self.ws:
|
if self.ws:
|
||||||
await self.ws.close()
|
await self.ws.close()
|
||||||
|
|
||||||
await self.publisher.stop()
|
|
||||||
|
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
|
|
||||||
data = msg.json()
|
data = msg.json()
|
||||||
|
|
|
||||||
|
|
@ -26,46 +26,66 @@ class EntityContextsExport:
|
||||||
self.subscriber = subscriber
|
self.subscriber = subscriber
|
||||||
|
|
||||||
async def destroy(self):
|
async def destroy(self):
|
||||||
|
# Step 1: Signal stop to prevent new messages
|
||||||
self.running.stop()
|
self.running.stop()
|
||||||
await self.ws.close()
|
|
||||||
|
# Step 2: Wait briefly for in-flight messages
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
|
||||||
|
if hasattr(self, 'subs'):
|
||||||
|
await self.subs.unsubscribe_all(self.id)
|
||||||
|
await self.subs.stop()
|
||||||
|
|
||||||
|
# Step 4: Close websocket last
|
||||||
|
if self.ws and not self.ws.closed:
|
||||||
|
await self.ws.close()
|
||||||
|
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
# Ignore incoming info from websocket
|
# Ignore incoming info from websocket
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
"""Enhanced run with better error handling"""
|
||||||
subs = Subscriber(
|
self.subs = Subscriber(
|
||||||
client = self.pulsar_client, topic = self.queue,
|
client = self.pulsar_client,
|
||||||
consumer_name = self.consumer, subscription = self.subscriber,
|
topic = self.queue,
|
||||||
schema = EntityContexts
|
consumer_name = self.consumer,
|
||||||
|
subscription = self.subscriber,
|
||||||
|
schema = EntityContexts,
|
||||||
|
backpressure_strategy = "block" # Configurable
|
||||||
)
|
)
|
||||||
|
|
||||||
await subs.start()
|
await self.subs.start()
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
self.id = str(uuid.uuid4())
|
||||||
q = await subs.subscribe_all(id)
|
q = await self.subs.subscribe_all(self.id)
|
||||||
|
|
||||||
|
consecutive_errors = 0
|
||||||
|
max_consecutive_errors = 5
|
||||||
|
|
||||||
while self.running.get():
|
while self.running.get():
|
||||||
try:
|
try:
|
||||||
|
|
||||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||||
await self.ws.send_json(serialize_entity_contexts(resp))
|
await self.ws.send_json(serialize_entity_contexts(resp))
|
||||||
|
consecutive_errors = 0 # Reset on success
|
||||||
|
|
||||||
except TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception: {str(e)}", exc_info=True)
|
logger.error(f"Exception sending to websocket: {str(e)}")
|
||||||
break
|
consecutive_errors += 1
|
||||||
|
|
||||||
await subs.unsubscribe_all(id)
|
if consecutive_errors >= max_consecutive_errors:
|
||||||
|
logger.error("Too many consecutive errors, shutting down")
|
||||||
|
break
|
||||||
|
|
||||||
await subs.stop()
|
# Brief pause before retry
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
await self.ws.close()
|
# Graceful cleanup handled in destroy()
|
||||||
self.running.stop()
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
|
import logging
|
||||||
from aiohttp import WSMsgType
|
from aiohttp import WSMsgType
|
||||||
|
|
||||||
from ... schema import Metadata
|
from ... schema import Metadata
|
||||||
|
|
@ -9,6 +10,9 @@ from ... base import Publisher
|
||||||
|
|
||||||
from . serialize import to_subgraph, to_value
|
from . serialize import to_subgraph, to_value
|
||||||
|
|
||||||
|
# Module logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class EntityContextsImport:
|
class EntityContextsImport:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -26,13 +30,17 @@ class EntityContextsImport:
|
||||||
await self.publisher.start()
|
await self.publisher.start()
|
||||||
|
|
||||||
async def destroy(self):
|
async def destroy(self):
|
||||||
|
# Step 1: Stop accepting new messages
|
||||||
self.running.stop()
|
self.running.stop()
|
||||||
|
|
||||||
|
# Step 2: Wait for publisher to drain its queue
|
||||||
|
logger.info("Draining publisher queue...")
|
||||||
|
await self.publisher.stop()
|
||||||
|
|
||||||
|
# Step 3: Close websocket only after queue is drained
|
||||||
if self.ws:
|
if self.ws:
|
||||||
await self.ws.close()
|
await self.ws.close()
|
||||||
|
|
||||||
await self.publisher.stop()
|
|
||||||
|
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
|
|
||||||
data = msg.json()
|
data = msg.json()
|
||||||
|
|
|
||||||
|
|
@ -26,46 +26,66 @@ class GraphEmbeddingsExport:
|
||||||
self.subscriber = subscriber
|
self.subscriber = subscriber
|
||||||
|
|
||||||
async def destroy(self):
|
async def destroy(self):
|
||||||
|
# Step 1: Signal stop to prevent new messages
|
||||||
self.running.stop()
|
self.running.stop()
|
||||||
await self.ws.close()
|
|
||||||
|
# Step 2: Wait briefly for in-flight messages
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
|
||||||
|
if hasattr(self, 'subs'):
|
||||||
|
await self.subs.unsubscribe_all(self.id)
|
||||||
|
await self.subs.stop()
|
||||||
|
|
||||||
|
# Step 4: Close websocket last
|
||||||
|
if self.ws and not self.ws.closed:
|
||||||
|
await self.ws.close()
|
||||||
|
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
# Ignore incoming info from websocket
|
# Ignore incoming info from websocket
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
"""Enhanced run with better error handling"""
|
||||||
subs = Subscriber(
|
self.subs = Subscriber(
|
||||||
client = self.pulsar_client, topic = self.queue,
|
client = self.pulsar_client,
|
||||||
consumer_name = self.consumer, subscription = self.subscriber,
|
topic = self.queue,
|
||||||
schema = GraphEmbeddings
|
consumer_name = self.consumer,
|
||||||
|
subscription = self.subscriber,
|
||||||
|
schema = GraphEmbeddings,
|
||||||
|
backpressure_strategy = "block" # Configurable
|
||||||
)
|
)
|
||||||
|
|
||||||
await subs.start()
|
await self.subs.start()
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
self.id = str(uuid.uuid4())
|
||||||
q = await subs.subscribe_all(id)
|
q = await self.subs.subscribe_all(self.id)
|
||||||
|
|
||||||
|
consecutive_errors = 0
|
||||||
|
max_consecutive_errors = 5
|
||||||
|
|
||||||
while self.running.get():
|
while self.running.get():
|
||||||
try:
|
try:
|
||||||
|
|
||||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||||
await self.ws.send_json(serialize_graph_embeddings(resp))
|
await self.ws.send_json(serialize_graph_embeddings(resp))
|
||||||
|
consecutive_errors = 0 # Reset on success
|
||||||
|
|
||||||
except TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception: {str(e)}", exc_info=True)
|
logger.error(f"Exception sending to websocket: {str(e)}")
|
||||||
break
|
consecutive_errors += 1
|
||||||
|
|
||||||
await subs.unsubscribe_all(id)
|
if consecutive_errors >= max_consecutive_errors:
|
||||||
|
logger.error("Too many consecutive errors, shutting down")
|
||||||
|
break
|
||||||
|
|
||||||
await subs.stop()
|
# Brief pause before retry
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
await self.ws.close()
|
# Graceful cleanup handled in destroy()
|
||||||
self.running.stop()
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
|
import logging
|
||||||
from aiohttp import WSMsgType
|
from aiohttp import WSMsgType
|
||||||
|
|
||||||
from ... schema import Metadata
|
from ... schema import Metadata
|
||||||
|
|
@ -9,6 +10,9 @@ from ... base import Publisher
|
||||||
|
|
||||||
from . serialize import to_subgraph, to_value
|
from . serialize import to_subgraph, to_value
|
||||||
|
|
||||||
|
# Module logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class GraphEmbeddingsImport:
|
class GraphEmbeddingsImport:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -26,13 +30,17 @@ class GraphEmbeddingsImport:
|
||||||
await self.publisher.start()
|
await self.publisher.start()
|
||||||
|
|
||||||
async def destroy(self):
|
async def destroy(self):
|
||||||
|
# Step 1: Stop accepting new messages
|
||||||
self.running.stop()
|
self.running.stop()
|
||||||
|
|
||||||
|
# Step 2: Wait for publisher to drain its queue
|
||||||
|
logger.info("Draining publisher queue...")
|
||||||
|
await self.publisher.stop()
|
||||||
|
|
||||||
|
# Step 3: Close websocket only after queue is drained
|
||||||
if self.ws:
|
if self.ws:
|
||||||
await self.ws.close()
|
await self.ws.close()
|
||||||
|
|
||||||
await self.publisher.stop()
|
|
||||||
|
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
|
|
||||||
data = msg.json()
|
data = msg.json()
|
||||||
|
|
|
||||||
|
|
@ -26,46 +26,66 @@ class TriplesExport:
|
||||||
self.subscriber = subscriber
|
self.subscriber = subscriber
|
||||||
|
|
||||||
async def destroy(self):
|
async def destroy(self):
|
||||||
|
# Step 1: Signal stop to prevent new messages
|
||||||
self.running.stop()
|
self.running.stop()
|
||||||
await self.ws.close()
|
|
||||||
|
# Step 2: Wait briefly for in-flight messages
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
|
||||||
|
if hasattr(self, 'subs'):
|
||||||
|
await self.subs.unsubscribe_all(self.id)
|
||||||
|
await self.subs.stop()
|
||||||
|
|
||||||
|
# Step 4: Close websocket last
|
||||||
|
if self.ws and not self.ws.closed:
|
||||||
|
await self.ws.close()
|
||||||
|
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
# Ignore incoming info from websocket
|
# Ignore incoming info from websocket
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
"""Enhanced run with better error handling"""
|
||||||
subs = Subscriber(
|
self.subs = Subscriber(
|
||||||
client = self.pulsar_client, topic = self.queue,
|
client = self.pulsar_client,
|
||||||
consumer_name = self.consumer, subscription = self.subscriber,
|
topic = self.queue,
|
||||||
schema = Triples
|
consumer_name = self.consumer,
|
||||||
|
subscription = self.subscriber,
|
||||||
|
schema = Triples,
|
||||||
|
backpressure_strategy = "block" # Configurable
|
||||||
)
|
)
|
||||||
|
|
||||||
await subs.start()
|
await self.subs.start()
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
self.id = str(uuid.uuid4())
|
||||||
q = await subs.subscribe_all(id)
|
q = await self.subs.subscribe_all(self.id)
|
||||||
|
|
||||||
|
consecutive_errors = 0
|
||||||
|
max_consecutive_errors = 5
|
||||||
|
|
||||||
while self.running.get():
|
while self.running.get():
|
||||||
try:
|
try:
|
||||||
|
|
||||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||||
await self.ws.send_json(serialize_triples(resp))
|
await self.ws.send_json(serialize_triples(resp))
|
||||||
|
consecutive_errors = 0 # Reset on success
|
||||||
|
|
||||||
except TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception: {str(e)}", exc_info=True)
|
logger.error(f"Exception sending to websocket: {str(e)}")
|
||||||
break
|
consecutive_errors += 1
|
||||||
|
|
||||||
await subs.unsubscribe_all(id)
|
if consecutive_errors >= max_consecutive_errors:
|
||||||
|
logger.error("Too many consecutive errors, shutting down")
|
||||||
|
break
|
||||||
|
|
||||||
await subs.stop()
|
# Brief pause before retry
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
await self.ws.close()
|
# Graceful cleanup handled in destroy()
|
||||||
self.running.stop()
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
|
import logging
|
||||||
from aiohttp import WSMsgType
|
from aiohttp import WSMsgType
|
||||||
|
|
||||||
from ... schema import Metadata
|
from ... schema import Metadata
|
||||||
|
|
@ -9,6 +10,9 @@ from ... base import Publisher
|
||||||
|
|
||||||
from . serialize import to_subgraph
|
from . serialize import to_subgraph
|
||||||
|
|
||||||
|
# Module logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class TriplesImport:
|
class TriplesImport:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -26,13 +30,17 @@ class TriplesImport:
|
||||||
await self.publisher.start()
|
await self.publisher.start()
|
||||||
|
|
||||||
async def destroy(self):
|
async def destroy(self):
|
||||||
|
# Step 1: Stop accepting new messages
|
||||||
self.running.stop()
|
self.running.stop()
|
||||||
|
|
||||||
|
# Step 2: Wait for publisher to drain its queue
|
||||||
|
logger.info("Draining publisher queue...")
|
||||||
|
await self.publisher.stop()
|
||||||
|
|
||||||
|
# Step 3: Close websocket only after queue is drained
|
||||||
if self.ws:
|
if self.ws:
|
||||||
await self.ws.close()
|
await self.ws.close()
|
||||||
|
|
||||||
await self.publisher.stop()
|
|
||||||
|
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
|
|
||||||
data = msg.json()
|
data = msg.json()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue