mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Subscriber resilience and RabbitMQ fixes (#765)
Subscriber resilience: recreate consumer after connection failure - Move consumer creation from Subscriber.start() into the run() loop, matching the pattern used by Consumer. If the connection drops and the consumer is closed in the finally block, the loop now recreates it on the next iteration instead of spinning forever on a None consumer. Consumer thread safety: - Dedicated ThreadPoolExecutor per consumer so all pika operations (create, receive, acknowledge, negative_acknowledge) run on the same thread — pika BlockingConnection is not thread-safe - Applies to both Consumer and Subscriber classes Config handler type audit — fix four mismatched type registrations: - librarian: was ["librarian"] (non-existent type), now ["flow", "active-flow"] (matches config["flow"] that the handler reads) - cores/service: was ["kg-core"], now ["flow"] (reads config["flow"]) - metering/counter: was ["token-costs"], now ["token-cost"] (singular) - agent/mcp_tool: was ["mcp-tool"], now ["mcp"] (reads config["mcp"]) Update tests
This commit is contained in:
parent
ddd4bd7790
commit
c20e6540ec
9 changed files with 96 additions and 66 deletions
|
|
@ -61,23 +61,21 @@ async def test_subscriber_deferred_acknowledgment_success():
|
||||||
max_size=10,
|
max_size=10,
|
||||||
backpressure_strategy="block"
|
backpressure_strategy="block"
|
||||||
)
|
)
|
||||||
|
subscriber.consumer = mock_consumer
|
||||||
# Start subscriber to initialize consumer
|
|
||||||
await subscriber.start()
|
|
||||||
|
|
||||||
# Create queue for subscription
|
# Create queue for subscription
|
||||||
queue = await subscriber.subscribe("test-queue")
|
queue = await subscriber.subscribe("test-queue")
|
||||||
|
|
||||||
# Create mock message with matching queue name
|
# Create mock message with matching queue name
|
||||||
msg = create_mock_message("test-queue", {"data": "test"})
|
msg = create_mock_message("test-queue", {"data": "test"})
|
||||||
|
|
||||||
# Process message
|
# Process message
|
||||||
await subscriber._process_message(msg)
|
await subscriber._process_message(msg)
|
||||||
|
|
||||||
# Should acknowledge successful delivery
|
# Should acknowledge successful delivery
|
||||||
mock_consumer.acknowledge.assert_called_once_with(msg)
|
mock_consumer.acknowledge.assert_called_once_with(msg)
|
||||||
mock_consumer.negative_acknowledge.assert_not_called()
|
mock_consumer.negative_acknowledge.assert_not_called()
|
||||||
|
|
||||||
# Message should be in queue
|
# Message should be in queue
|
||||||
assert not queue.empty()
|
assert not queue.empty()
|
||||||
received_msg = await queue.get()
|
received_msg = await queue.get()
|
||||||
|
|
@ -108,9 +106,7 @@ async def test_subscriber_dropped_message_still_acks():
|
||||||
max_size=1, # Very small queue
|
max_size=1, # Very small queue
|
||||||
backpressure_strategy="drop_new"
|
backpressure_strategy="drop_new"
|
||||||
)
|
)
|
||||||
|
subscriber.consumer = mock_consumer
|
||||||
# Start subscriber to initialize consumer
|
|
||||||
await subscriber.start()
|
|
||||||
|
|
||||||
# Create queue and fill it
|
# Create queue and fill it
|
||||||
queue = await subscriber.subscribe("test-queue")
|
queue = await subscriber.subscribe("test-queue")
|
||||||
|
|
@ -151,9 +147,7 @@ async def test_subscriber_orphaned_message_acks():
|
||||||
max_size=10,
|
max_size=10,
|
||||||
backpressure_strategy="block"
|
backpressure_strategy="block"
|
||||||
)
|
)
|
||||||
|
subscriber.consumer = mock_consumer
|
||||||
# Start subscriber to initialize consumer
|
|
||||||
await subscriber.start()
|
|
||||||
|
|
||||||
# Don't create any queues - message will be orphaned
|
# Don't create any queues - message will be orphaned
|
||||||
# This simulates a response arriving after the waiter has unsubscribed
|
# This simulates a response arriving after the waiter has unsubscribed
|
||||||
|
|
@ -189,9 +183,7 @@ async def test_subscriber_backpressure_strategies():
|
||||||
max_size=2,
|
max_size=2,
|
||||||
backpressure_strategy="drop_oldest"
|
backpressure_strategy="drop_oldest"
|
||||||
)
|
)
|
||||||
|
subscriber.consumer = mock_consumer
|
||||||
# Start subscriber to initialize consumer
|
|
||||||
await subscriber.start()
|
|
||||||
|
|
||||||
queue = await subscriber.subscribe("test-queue")
|
queue = await subscriber.subscribe("test-queue")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -81,9 +81,8 @@ class TestTaskGroupConcurrency:
|
||||||
|
|
||||||
# Track how many consume_from_queue calls are made
|
# Track how many consume_from_queue calls are made
|
||||||
call_count = 0
|
call_count = 0
|
||||||
original_running = True
|
|
||||||
|
|
||||||
async def mock_consume(backend_consumer):
|
async def mock_consume(backend_consumer, executor=None):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
# Wait a bit to let all tasks start, then signal stop
|
# Wait a bit to let all tasks start, then signal stop
|
||||||
|
|
@ -107,7 +106,7 @@ class TestTaskGroupConcurrency:
|
||||||
consumer = _make_consumer(concurrency=1)
|
consumer = _make_consumer(concurrency=1)
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_consume(backend_consumer):
|
async def mock_consume(backend_consumer, executor=None):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
|
|
@ -294,9 +293,8 @@ class TestPollTimeout:
|
||||||
raise type('Timeout', (Exception,), {})("timeout")
|
raise type('Timeout', (Exception,), {})("timeout")
|
||||||
|
|
||||||
mock_pulsar_consumer.receive = capture_receive
|
mock_pulsar_consumer.receive = capture_receive
|
||||||
consumer.consumer = mock_pulsar_consumer
|
|
||||||
|
|
||||||
await consumer.consume_from_queue()
|
await consumer.consume_from_queue(mock_pulsar_consumer)
|
||||||
|
|
||||||
assert received_kwargs.get("timeout_millis") == 100
|
assert received_kwargs.get("timeout_millis") == 100
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -94,7 +94,6 @@ class AsyncProcessor:
|
||||||
metrics = config_consumer_metrics,
|
metrics = config_consumer_metrics,
|
||||||
|
|
||||||
start_of_messages = False,
|
start_of_messages = False,
|
||||||
consumer_type = 'exclusive',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.running = True
|
self.running = True
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
from .. exceptions import TooManyRequests
|
from .. exceptions import TooManyRequests
|
||||||
|
|
||||||
|
|
@ -110,29 +111,37 @@ class Consumer:
|
||||||
logger.info(f"Starting {self.concurrency} receiver threads")
|
logger.info(f"Starting {self.concurrency} receiver threads")
|
||||||
|
|
||||||
# Create one backend consumer per concurrent task.
|
# Create one backend consumer per concurrent task.
|
||||||
# Each gets its own connection — required for backends
|
# Each gets its own connection and dedicated thread —
|
||||||
# like RabbitMQ where connections are not thread-safe.
|
# required for backends like RabbitMQ where connections
|
||||||
|
# are not thread-safe (pika BlockingConnection must be
|
||||||
|
# used from a single thread).
|
||||||
consumers = []
|
consumers = []
|
||||||
|
executors = []
|
||||||
for i in range(self.concurrency):
|
for i in range(self.concurrency):
|
||||||
try:
|
try:
|
||||||
logger.info(f"Subscribing to topic: {self.topic} (worker {i})")
|
logger.info(f"Subscribing to topic: {self.topic} (worker {i})")
|
||||||
c = await asyncio.to_thread(
|
executor = ThreadPoolExecutor(max_workers=1)
|
||||||
self.backend.create_consumer,
|
loop = asyncio.get_event_loop()
|
||||||
topic = self.topic,
|
c = await loop.run_in_executor(
|
||||||
subscription = self.subscriber,
|
executor,
|
||||||
schema = self.schema,
|
lambda: self.backend.create_consumer(
|
||||||
initial_position = initial_pos,
|
topic = self.topic,
|
||||||
consumer_type = self.consumer_type,
|
subscription = self.subscriber,
|
||||||
|
schema = self.schema,
|
||||||
|
initial_position = initial_pos,
|
||||||
|
consumer_type = self.consumer_type,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
consumers.append(c)
|
consumers.append(c)
|
||||||
|
executors.append(executor)
|
||||||
logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})")
|
logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True)
|
logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async with asyncio.TaskGroup() as tg:
|
async with asyncio.TaskGroup() as tg:
|
||||||
for c in consumers:
|
for c, ex in zip(consumers, executors):
|
||||||
tg.create_task(self.consume_from_queue(c))
|
tg.create_task(self.consume_from_queue(c, ex))
|
||||||
|
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.state("stopped")
|
self.metrics.state("stopped")
|
||||||
|
|
@ -146,7 +155,10 @@ class Consumer:
|
||||||
c.close()
|
c.close()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
for ex in executors:
|
||||||
|
ex.shutdown(wait=False)
|
||||||
consumers = []
|
consumers = []
|
||||||
|
executors = []
|
||||||
await asyncio.sleep(self.reconnect_time)
|
await asyncio.sleep(self.reconnect_time)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -157,15 +169,18 @@ class Consumer:
|
||||||
c.close()
|
c.close()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
for ex in executors:
|
||||||
|
ex.shutdown(wait=False)
|
||||||
|
|
||||||
async def consume_from_queue(self, consumer):
|
async def consume_from_queue(self, consumer, executor=None):
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
while self.running:
|
while self.running:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg = await asyncio.to_thread(
|
msg = await loop.run_in_executor(
|
||||||
consumer.receive,
|
executor,
|
||||||
timeout_millis=100
|
lambda: consumer.receive(timeout_millis=100),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle timeout from any backend
|
# Handle timeout from any backend
|
||||||
|
|
@ -173,10 +188,11 @@ class Consumer:
|
||||||
continue
|
continue
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
await self.handle_one_from_queue(msg, consumer)
|
await self.handle_one_from_queue(msg, consumer, executor)
|
||||||
|
|
||||||
async def handle_one_from_queue(self, msg, consumer):
|
async def handle_one_from_queue(self, msg, consumer, executor=None):
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
expiry = time.time() + self.rate_limit_timeout
|
expiry = time.time() + self.rate_limit_timeout
|
||||||
|
|
||||||
# This loop is for retry on rate-limit / resource limits
|
# This loop is for retry on rate-limit / resource limits
|
||||||
|
|
@ -187,8 +203,11 @@ class Consumer:
|
||||||
logger.warning("Gave up waiting for rate-limit retry")
|
logger.warning("Gave up waiting for rate-limit retry")
|
||||||
|
|
||||||
# Message failed to be processed, this causes it to
|
# Message failed to be processed, this causes it to
|
||||||
# be retried
|
# be retried. Ack on the consumer's dedicated thread
|
||||||
consumer.negative_acknowledge(msg)
|
# (pika is not thread-safe).
|
||||||
|
await loop.run_in_executor(
|
||||||
|
executor, lambda: consumer.negative_acknowledge(msg)
|
||||||
|
)
|
||||||
|
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.process("error")
|
self.metrics.process("error")
|
||||||
|
|
@ -210,8 +229,11 @@ class Consumer:
|
||||||
|
|
||||||
logger.debug("Message processed successfully")
|
logger.debug("Message processed successfully")
|
||||||
|
|
||||||
# Acknowledge successful processing of the message
|
# Acknowledge on the consumer's dedicated thread
|
||||||
consumer.acknowledge(msg)
|
# (pika is not thread-safe)
|
||||||
|
await loop.run_in_executor(
|
||||||
|
executor, lambda: consumer.acknowledge(msg)
|
||||||
|
)
|
||||||
|
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.process("success")
|
self.metrics.process("success")
|
||||||
|
|
@ -237,8 +259,10 @@ class Consumer:
|
||||||
logger.error(f"Message processing exception: {e}", exc_info=True)
|
logger.error(f"Message processing exception: {e}", exc_info=True)
|
||||||
|
|
||||||
# Message failed to be processed, this causes it to
|
# Message failed to be processed, this causes it to
|
||||||
# be retried
|
# be retried. Ack on the consumer's dedicated thread.
|
||||||
consumer.negative_acknowledge(msg)
|
await loop.run_in_executor(
|
||||||
|
executor, lambda: consumer.negative_acknowledge(msg)
|
||||||
|
)
|
||||||
|
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.process("error")
|
self.metrics.process("error")
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import asyncio
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -38,6 +39,7 @@ class Subscriber:
|
||||||
self.pending_acks = {} # Track messages awaiting delivery
|
self.pending_acks = {} # Track messages awaiting delivery
|
||||||
|
|
||||||
self.consumer = None
|
self.consumer = None
|
||||||
|
self.executor = None
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|
||||||
|
|
@ -45,15 +47,6 @@ class Subscriber:
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
|
||||||
# Create consumer via backend
|
|
||||||
self.consumer = await asyncio.to_thread(
|
|
||||||
self.backend.create_consumer,
|
|
||||||
topic=self.topic,
|
|
||||||
subscription=self.subscription,
|
|
||||||
schema=self.schema,
|
|
||||||
consumer_type='exclusive',
|
|
||||||
)
|
|
||||||
|
|
||||||
self.task = asyncio.create_task(self.run())
|
self.task = asyncio.create_task(self.run())
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
|
|
@ -80,6 +73,21 @@ class Subscriber:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
# Create consumer and dedicated thread if needed
|
||||||
|
# (first run or after failure)
|
||||||
|
if self.consumer is None:
|
||||||
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
self.consumer = await loop.run_in_executor(
|
||||||
|
self.executor,
|
||||||
|
lambda: self.backend.create_consumer(
|
||||||
|
topic=self.topic,
|
||||||
|
subscription=self.subscription,
|
||||||
|
schema=self.schema,
|
||||||
|
consumer_type='exclusive',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.state("running")
|
self.metrics.state("running")
|
||||||
|
|
||||||
|
|
@ -128,9 +136,12 @@ class Subscriber:
|
||||||
# Process messages only if not draining
|
# Process messages only if not draining
|
||||||
if not self.draining:
|
if not self.draining:
|
||||||
try:
|
try:
|
||||||
msg = await asyncio.to_thread(
|
loop = asyncio.get_event_loop()
|
||||||
self.consumer.receive,
|
msg = await loop.run_in_executor(
|
||||||
timeout_millis=250
|
self.executor,
|
||||||
|
lambda: self.consumer.receive(
|
||||||
|
timeout_millis=250
|
||||||
|
),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle timeout from any backend
|
# Handle timeout from any backend
|
||||||
|
|
@ -172,15 +183,18 @@ class Subscriber:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Already closed or error
|
pass # Already closed or error
|
||||||
self.consumer = None
|
self.consumer = None
|
||||||
|
|
||||||
|
if self.executor:
|
||||||
|
self.executor.shutdown(wait=False)
|
||||||
|
self.executor = None
|
||||||
|
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.state("stopped")
|
self.metrics.state("stopped")
|
||||||
|
|
||||||
if not self.running and not self.draining:
|
if not self.running and not self.draining:
|
||||||
return
|
return
|
||||||
|
|
||||||
# If handler drops out, sleep a retry
|
# Sleep before retry
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def subscribe(self, id):
|
async def subscribe(self, id):
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ class Service(ToolService):
|
||||||
**params
|
**params
|
||||||
)
|
)
|
||||||
|
|
||||||
self.register_config_handler(self.on_mcp_config, types=["mcp-tool"])
|
self.register_config_handler(self.on_mcp_config, types=["mcp"])
|
||||||
|
|
||||||
self.mcp_services = {}
|
self.mcp_services = {}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,7 @@ class Processor(AsyncProcessor):
|
||||||
flow_config = self,
|
flow_config = self,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.register_config_handler(self.on_knowledge_config, types=["kg-core"])
|
self.register_config_handler(self.on_knowledge_config, types=["flow"])
|
||||||
|
|
||||||
self.flows = {}
|
self.flows = {}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -246,7 +246,10 @@ class Processor(AsyncProcessor):
|
||||||
taskgroup = self.taskgroup,
|
taskgroup = self.taskgroup,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.register_config_handler(self.on_librarian_config, types=["librarian"])
|
self.register_config_handler(
|
||||||
|
self.on_librarian_config,
|
||||||
|
types=["flow", "active-flow"],
|
||||||
|
)
|
||||||
|
|
||||||
self.flows = {}
|
self.flows = {}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ class Processor(FlowProcessor):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.register_config_handler(self.on_cost_config, types=["token-costs"])
|
self.register_config_handler(self.on_cost_config, types=["token-cost"])
|
||||||
|
|
||||||
self.register_specification(
|
self.register_specification(
|
||||||
ConsumerSpec(
|
ConsumerSpec(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue