trustgraph/trustgraph-base/trustgraph/base/subscriber.py
cybermaggedon c20e6540ec
Subscriber resilience and RabbitMQ fixes (#765)
Subscriber resilience: recreate consumer after connection failure

- Move consumer creation from Subscriber.start() into the run() loop,
  matching the pattern used by Consumer. If the connection drops and the
  consumer is closed in the finally block, the loop now recreates it on
  the next iteration instead of spinning forever on a None consumer.

Consumer thread safety:
- Dedicated ThreadPoolExecutor per consumer so all pika operations
  (create, receive, acknowledge, negative_acknowledge) run on the
  same thread — pika BlockingConnection is not thread-safe
- Applies to both Consumer and Subscriber classes

Config handler type audit — fix four mismatched type registrations:
- librarian: was ["librarian"] (non-existent type), now ["flow",
  "active-flow"] (matches config["flow"] that the handler reads)
- cores/service: was ["kg-core"], now ["flow"] (reads
  config["flow"])
- metering/counter: was ["token-costs"], now ["token-cost"]
  (singular)
- agent/mcp_tool: was ["mcp-tool"], now ["mcp"] (reads
  config["mcp"])

Update tests
2026-04-07 14:51:14 +01:00

316 lines
11 KiB
Python

# Subscriber is similar to consumer: It provides a service to take stuff
# off of a queue and make it available using an internal broker system,
# so suitable for when multiple recipients are reading from the same queue
import asyncio
import time
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
# Module logger
logger = logging.getLogger(__name__)
# Timeout exception - can come from different backends
class TimeoutError(Exception):
pass
class Subscriber:
def __init__(self, backend, topic, subscription, consumer_name,
schema=None, max_size=100, metrics=None,
backpressure_strategy="block", drain_timeout=5.0):
self.backend = backend # Changed from 'client' to 'backend'
self.topic = topic
self.subscription = subscription
self.consumer_name = consumer_name
self.schema = schema
self.q = {}
self.full = {}
self.max_size = max_size
self.lock = asyncio.Lock()
self.running = True
self.draining = False # New state for graceful shutdown
self.metrics = metrics
self.task = None
self.backpressure_strategy = backpressure_strategy
self.drain_timeout = drain_timeout
self.pending_acks = {} # Track messages awaiting delivery
self.consumer = None
self.executor = None
def __del__(self):
self.running = False
async def start(self):
self.task = asyncio.create_task(self.run())
async def stop(self):
"""Initiate graceful shutdown with draining"""
self.running = False
self.draining = True
if self.task:
# Wait for run() to complete draining
await self.task
async def join(self):
await self.stop()
if self.task:
await self.task
async def run(self):
"""Enhanced run method with integrated draining logic"""
while self.running or self.draining:
if self.metrics:
self.metrics.state("stopped")
try:
# Create consumer and dedicated thread if needed
# (first run or after failure)
if self.consumer is None:
self.executor = ThreadPoolExecutor(max_workers=1)
loop = asyncio.get_event_loop()
self.consumer = await loop.run_in_executor(
self.executor,
lambda: self.backend.create_consumer(
topic=self.topic,
subscription=self.subscription,
schema=self.schema,
consumer_type='exclusive',
),
)
if self.metrics:
self.metrics.state("running")
logger.info("Subscriber running...")
drain_end_time = None
while self.running or self.draining:
# Start drain timeout when entering drain mode
if self.draining and drain_end_time is None:
drain_end_time = time.time() + self.drain_timeout
logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s")
# Stop accepting new messages during drain
# Note: Not all backends support pausing message listeners
if self.consumer and hasattr(self.consumer, 'pause_message_listener'):
try:
self.consumer.pause_message_listener()
except Exception:
# Not all consumers support message listeners
pass
# Check drain timeout
if self.draining and drain_end_time and time.time() > drain_end_time:
async with self.lock:
total_pending = sum(
q.qsize() for q in
list(self.q.values()) + list(self.full.values())
)
if total_pending > 0:
logger.warning(f"Drain timeout reached with {total_pending} messages in queues")
self.draining = False
break
# Check if we can exit drain mode
if self.draining:
async with self.lock:
all_empty = all(
q.empty() for q in
list(self.q.values()) + list(self.full.values())
)
if all_empty and len(self.pending_acks) == 0:
logger.info("Subscriber queues drained successfully")
self.draining = False
break
# Process messages only if not draining
if not self.draining:
try:
loop = asyncio.get_event_loop()
msg = await loop.run_in_executor(
self.executor,
lambda: self.consumer.receive(
timeout_millis=250
),
)
except Exception as e:
# Handle timeout from any backend
if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower():
continue
logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
raise e
if self.metrics:
self.metrics.received()
# Process the message with deferred acknowledgment
await self._process_message(msg)
else:
# During draining, just wait for queues to empty
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"Subscriber exception: {e}", exc_info=True)
finally:
# Negative acknowledge any pending messages
for msg in self.pending_acks.values():
try:
self.consumer.negative_acknowledge(msg)
except Exception:
pass # Consumer already closed or error
self.pending_acks.clear()
if self.consumer:
if hasattr(self.consumer, 'unsubscribe'):
try:
self.consumer.unsubscribe()
except Exception:
pass # Already closed or error
try:
self.consumer.close()
except Exception:
pass # Already closed or error
self.consumer = None
if self.executor:
self.executor.shutdown(wait=False)
self.executor = None
if self.metrics:
self.metrics.state("stopped")
if not self.running and not self.draining:
return
# Sleep before retry
await asyncio.sleep(1)
async def subscribe(self, id):
async with self.lock:
q = asyncio.Queue(maxsize=self.max_size)
self.q[id] = q
return q
async def unsubscribe(self, id):
async with self.lock:
if id in self.q:
# self.q[id].shutdown(immediate=True)
del self.q[id]
async def subscribe_all(self, id):
async with self.lock:
q = asyncio.Queue(maxsize=self.max_size)
self.full[id] = q
return q
async def unsubscribe_all(self, id):
async with self.lock:
if id in self.full:
# self.full[id].shutdown(immediate=True)
del self.full[id]
async def _process_message(self, msg):
"""Process a single message with deferred acknowledgment"""
# Store message for later acknowledgment
msg_id = str(uuid.uuid4())
self.pending_acks[msg_id] = msg
try:
id = msg.properties()["id"]
except:
id = None
value = msg.value()
delivery_success = False
has_matching_waiter = False
async with self.lock:
# Deliver to specific subscribers
if id in self.q:
has_matching_waiter = True
delivery_success = await self._deliver_to_queue(
self.q[id], value
)
# Deliver to all subscribers
for q in self.full.values():
has_matching_waiter = True
if await self._deliver_to_queue(q, value):
delivery_success = True
# Always acknowledge the message to prevent redelivery storms
# on shared topics. Negative acknowledging orphaned messages
# (no matching waiter) causes immediate redelivery to all
# subscribers, none of whom can handle it either.
self.consumer.acknowledge(msg)
del self.pending_acks[msg_id]
if not delivery_success:
if not has_matching_waiter:
# Message arrived for a waiter that no longer exists
# (likely due to client disconnect or timeout)
logger.debug(
f"Discarding orphaned message with id={id} - "
"no matching waiter"
)
else:
# Delivery failed (e.g., queue full with drop_new strategy)
logger.debug(
f"Message with id={id} dropped due to backpressure"
)
async def _deliver_to_queue(self, queue, value):
"""Deliver message to queue with backpressure handling"""
try:
if self.backpressure_strategy == "block":
# Block until space available (no timeout)
await queue.put(value)
return True
elif self.backpressure_strategy == "drop_oldest":
# Drop oldest message if queue full
if queue.full():
try:
queue.get_nowait()
if self.metrics:
self.metrics.dropped()
except asyncio.QueueEmpty:
pass
await queue.put(value)
return True
elif self.backpressure_strategy == "drop_new":
# Drop new message if queue full
if queue.full():
if self.metrics:
self.metrics.dropped()
return False
await queue.put(value)
return True
except Exception as e:
logger.error(f"Failed to deliver message: {e}")
return False