diff --git a/tests/unit/test_base/test_subscriber_graceful_shutdown.py b/tests/unit/test_base/test_subscriber_graceful_shutdown.py index 0587e3d6..ec14f66b 100644 --- a/tests/unit/test_base/test_subscriber_graceful_shutdown.py +++ b/tests/unit/test_base/test_subscriber_graceful_shutdown.py @@ -61,23 +61,21 @@ async def test_subscriber_deferred_acknowledgment_success(): max_size=10, backpressure_strategy="block" ) - - # Start subscriber to initialize consumer - await subscriber.start() - + subscriber.consumer = mock_consumer + # Create queue for subscription queue = await subscriber.subscribe("test-queue") - + # Create mock message with matching queue name msg = create_mock_message("test-queue", {"data": "test"}) - + # Process message await subscriber._process_message(msg) - + # Should acknowledge successful delivery mock_consumer.acknowledge.assert_called_once_with(msg) mock_consumer.negative_acknowledge.assert_not_called() - + # Message should be in queue assert not queue.empty() received_msg = await queue.get() @@ -108,9 +106,7 @@ async def test_subscriber_dropped_message_still_acks(): max_size=1, # Very small queue backpressure_strategy="drop_new" ) - - # Start subscriber to initialize consumer - await subscriber.start() + subscriber.consumer = mock_consumer # Create queue and fill it queue = await subscriber.subscribe("test-queue") @@ -151,9 +147,7 @@ async def test_subscriber_orphaned_message_acks(): max_size=10, backpressure_strategy="block" ) - - # Start subscriber to initialize consumer - await subscriber.start() + subscriber.consumer = mock_consumer # Don't create any queues - message will be orphaned # This simulates a response arriving after the waiter has unsubscribed @@ -189,9 +183,7 @@ async def test_subscriber_backpressure_strategies(): max_size=2, backpressure_strategy="drop_oldest" ) - - # Start subscriber to initialize consumer - await subscriber.start() + subscriber.consumer = mock_consumer queue = await subscriber.subscribe("test-queue") diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py b/tests/unit/test_concurrency/test_consumer_concurrency.py index 03244b73..59c7f2b5 100644 --- a/tests/unit/test_concurrency/test_consumer_concurrency.py +++ b/tests/unit/test_concurrency/test_consumer_concurrency.py @@ -81,9 +81,8 @@ class TestTaskGroupConcurrency: # Track how many consume_from_queue calls are made call_count = 0 - original_running = True - async def mock_consume(backend_consumer): + async def mock_consume(backend_consumer, executor=None): nonlocal call_count call_count += 1 # Wait a bit to let all tasks start, then signal stop @@ -107,7 +106,7 @@ class TestTaskGroupConcurrency: consumer = _make_consumer(concurrency=1) call_count = 0 - async def mock_consume(backend_consumer): + async def mock_consume(backend_consumer, executor=None): nonlocal call_count call_count += 1 await asyncio.sleep(0.01) @@ -294,9 +293,8 @@ class TestPollTimeout: raise type('Timeout', (Exception,), {})("timeout") mock_pulsar_consumer.receive = capture_receive - consumer.consumer = mock_pulsar_consumer - await consumer.consume_from_queue() + await consumer.consume_from_queue(mock_pulsar_consumer) assert received_kwargs.get("timeout_millis") == 100 diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index c805bffa..4f04df16 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -94,7 +94,6 @@ class AsyncProcessor: metrics = config_consumer_metrics, start_of_messages = False, - consumer_type = 'exclusive', ) self.running = True diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 4f8c9de5..b6c28bbe 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -12,6 +12,7 @@ import asyncio import time import logging +from concurrent.futures import ThreadPoolExecutor from .. exceptions import TooManyRequests @@ -110,29 +111,37 @@ class Consumer: logger.info(f"Starting {self.concurrency} receiver threads") # Create one backend consumer per concurrent task. - # Each gets its own connection — required for backends - # like RabbitMQ where connections are not thread-safe. + # Each gets its own connection and dedicated thread — + # required for backends like RabbitMQ where connections + # are not thread-safe (pika BlockingConnection must be + # used from a single thread). consumers = [] + executors = [] for i in range(self.concurrency): try: logger.info(f"Subscribing to topic: {self.topic} (worker {i})") - c = await asyncio.to_thread( - self.backend.create_consumer, - topic = self.topic, - subscription = self.subscriber, - schema = self.schema, - initial_position = initial_pos, - consumer_type = self.consumer_type, + executor = ThreadPoolExecutor(max_workers=1) + loop = asyncio.get_event_loop() + c = await loop.run_in_executor( + executor, + lambda: self.backend.create_consumer( + topic = self.topic, + subscription = self.subscriber, + schema = self.schema, + initial_position = initial_pos, + consumer_type = self.consumer_type, + ), ) consumers.append(c) + executors.append(executor) logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})") except Exception as e: logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True) raise async with asyncio.TaskGroup() as tg: - for c in consumers: - tg.create_task(self.consume_from_queue(c)) + for c, ex in zip(consumers, executors): + tg.create_task(self.consume_from_queue(c, ex)) if self.metrics: self.metrics.state("stopped") @@ -146,7 +155,10 @@ class Consumer: c.close() except Exception: pass + for ex in executors: + ex.shutdown(wait=False) consumers = [] + executors = [] await asyncio.sleep(self.reconnect_time) continue @@ -157,15 +169,18 @@ class Consumer: c.close() except Exception: pass + for ex in executors: + ex.shutdown(wait=False) - async def consume_from_queue(self, consumer): + async def consume_from_queue(self, consumer, executor=None): + loop = asyncio.get_event_loop() while self.running: try: - msg = await asyncio.to_thread( - consumer.receive, - timeout_millis=100 + msg = await loop.run_in_executor( + executor, + lambda: consumer.receive(timeout_millis=100), ) except Exception as e: # Handle timeout from any backend @@ -173,10 +188,11 @@ class Consumer: continue raise e - await self.handle_one_from_queue(msg, consumer) + await self.handle_one_from_queue(msg, consumer, executor) - async def handle_one_from_queue(self, msg, consumer): + async def handle_one_from_queue(self, msg, consumer, executor=None): + loop = asyncio.get_event_loop() expiry = time.time() + self.rate_limit_timeout # This loop is for retry on rate-limit / resource limits @@ -187,8 +203,11 @@ class Consumer: logger.warning("Gave up waiting for rate-limit retry") # Message failed to be processed, this causes it to - # be retried - consumer.negative_acknowledge(msg) + # be retried. Ack on the consumer's dedicated thread + # (pika is not thread-safe). + await loop.run_in_executor( + executor, lambda: consumer.negative_acknowledge(msg) + ) if self.metrics: self.metrics.process("error") @@ -210,8 +229,11 @@ class Consumer: logger.debug("Message processed successfully") - # Acknowledge successful processing of the message - consumer.acknowledge(msg) + # Acknowledge on the consumer's dedicated thread + # (pika is not thread-safe) + await loop.run_in_executor( + executor, lambda: consumer.acknowledge(msg) + ) if self.metrics: self.metrics.process("success") @@ -237,8 +259,10 @@ class Consumer: logger.error(f"Message processing exception: {e}", exc_info=True) # Message failed to be processed, this causes it to - # be retried - consumer.negative_acknowledge(msg) + # be retried. Ack on the consumer's dedicated thread. + await loop.run_in_executor( + executor, lambda: consumer.negative_acknowledge(msg) + ) if self.metrics: self.metrics.process("error") diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py index 36948131..6cb234b1 100644 --- a/trustgraph-base/trustgraph/base/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -7,6 +7,7 @@ import asyncio import time import logging import uuid +from concurrent.futures import ThreadPoolExecutor # Module logger logger = logging.getLogger(__name__) @@ -38,6 +39,7 @@ class Subscriber: self.pending_acks = {} # Track messages awaiting delivery self.consumer = None + self.executor = None def __del__(self): @@ -45,15 +47,6 @@ class Subscriber: async def start(self): - # Create consumer via backend - self.consumer = await asyncio.to_thread( - self.backend.create_consumer, - topic=self.topic, - subscription=self.subscription, - schema=self.schema, - consumer_type='exclusive', - ) - self.task = asyncio.create_task(self.run()) async def stop(self): @@ -80,6 +73,21 @@ class Subscriber: try: + # Create consumer and dedicated thread if needed + # (first run or after failure) + if self.consumer is None: + self.executor = ThreadPoolExecutor(max_workers=1) + loop = asyncio.get_event_loop() + self.consumer = await loop.run_in_executor( + self.executor, + lambda: self.backend.create_consumer( + topic=self.topic, + subscription=self.subscription, + schema=self.schema, + consumer_type='exclusive', + ), + ) + if self.metrics: self.metrics.state("running") @@ -128,9 +136,12 @@ class Subscriber: # Process messages only if not draining if not self.draining: try: - msg = await asyncio.to_thread( - self.consumer.receive, - timeout_millis=250 + loop = asyncio.get_event_loop() + msg = await loop.run_in_executor( + self.executor, + lambda: self.consumer.receive( + timeout_millis=250 + ), ) except Exception as e: # Handle timeout from any backend @@ -172,15 +183,18 @@ class Subscriber: except Exception: pass # Already closed or error self.consumer = None - - + + if self.executor: + self.executor.shutdown(wait=False) + self.executor = None + if self.metrics: self.metrics.state("stopped") if not self.running and not self.draining: return - - # If handler drops out, sleep a retry + + # Sleep before retry await asyncio.sleep(1) async def subscribe(self, id): diff --git a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py index 0bc5d7e3..c793f9ca 100755 --- a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py +++ b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py @@ -24,7 +24,7 @@ class Service(ToolService): **params ) - self.register_config_handler(self.on_mcp_config, types=["mcp-tool"]) + self.register_config_handler(self.on_mcp_config, types=["mcp"]) self.mcp_services = {} diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index 3bc4e9b6..d6390805 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -108,7 +108,7 @@ class Processor(AsyncProcessor): flow_config = self, ) - self.register_config_handler(self.on_knowledge_config, types=["kg-core"]) + self.register_config_handler(self.on_knowledge_config, types=["flow"]) self.flows = {} diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index 15cc97fa..c735a550 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -246,7 +246,10 @@ class Processor(AsyncProcessor): taskgroup = self.taskgroup, ) - self.register_config_handler(self.on_librarian_config, types=["librarian"]) + self.register_config_handler( + self.on_librarian_config, + types=["flow", "active-flow"], + ) self.flows = {} diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py index f120a812..3e0b610c 100644 --- a/trustgraph-flow/trustgraph/metering/counter.py +++ b/trustgraph-flow/trustgraph/metering/counter.py @@ -40,7 +40,7 @@ class Processor(FlowProcessor): } ) - self.register_config_handler(self.on_cost_config, types=["token-costs"]) + self.register_config_handler(self.on_cost_config, types=["token-cost"]) self.register_specification( ConsumerSpec(