trustgraph/tests/unit/test_concurrency/test_consumer_concurrency.py
cybermaggedon c20e6540ec
Subscriber resilience and RabbitMQ fixes (#765)
Subscriber resilience: recreate consumer after connection failure

- Move consumer creation from Subscriber.start() into the run() loop,
  matching the pattern used by Consumer. If the connection drops and the
  consumer is closed in the finally block, the loop now recreates it on
  the next iteration instead of spinning forever on a None consumer.

Consumer thread safety:
- Dedicated ThreadPoolExecutor per consumer so all pika operations
  (create, receive, acknowledge, negative_acknowledge) run on the
  same thread — pika BlockingConnection is not thread-safe
- Applies to both Consumer and Subscriber classes

Config handler type audit — fix four mismatched type registrations:
- librarian: was ["librarian"] (non-existent type), now ["flow",
  "active-flow"] (matches config["flow"] that the handler reads)
- cores/service: was ["kg-core"], now ["flow"] (reads
  config["flow"])
- metering/counter: was ["token-costs"], now ["token-cost"]
  (singular)
- agent/mcp_tool: was ["mcp-tool"], now ["mcp"] (reads
  config["mcp"])

Update tests
2026-04-07 14:51:14 +01:00

319 lines
10 KiB
Python

"""
Tests for Consumer concurrency: TaskGroup-based concurrent message processing,
rate-limit retry with backpressure, and message acknowledgement.
"""
import asyncio
import time
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.base.consumer import Consumer
from trustgraph.exceptions import TooManyRequests
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_consumer(
    concurrency=1,
    handler=None,
    rate_limit_retry_time=0.01,
    rate_limit_timeout=1,
):
    """Build a Consumer whose collaborators are all mocks.

    The task group, flow, backend and schema are plain ``MagicMock``
    objects. When no ``handler`` is supplied an ``AsyncMock`` stands in
    for it. ``concurrency`` and the two rate-limit settings are passed
    through to ``Consumer`` unchanged.
    """
    if not handler:
        handler = AsyncMock()

    return Consumer(
        taskgroup=MagicMock(),
        flow=MagicMock(),
        backend=MagicMock(),
        schema=MagicMock(),
        topic="test-topic",
        subscriber="test-sub",
        handler=handler,
        concurrency=concurrency,
        rate_limit_retry_time=rate_limit_retry_time,
        rate_limit_timeout=rate_limit_timeout,
    )
def _make_msg():
"""Create a mock Pulsar message."""
return MagicMock()
# ---------------------------------------------------------------------------
# Concurrency configuration tests
# ---------------------------------------------------------------------------
class TestConcurrencyConfiguration:
    """Checks that the ``concurrency`` setting ends up on the Consumer."""

    def test_default_concurrency_is_1(self):
        """A Consumer built without ``concurrency`` reports a value of 1.

        The Consumer is constructed directly rather than through
        ``_make_consumer``: that helper always passes ``concurrency``
        explicitly, so going through it would only exercise the helper's
        default and never the Consumer's own.
        """
        consumer = Consumer(
            taskgroup=MagicMock(),
            flow=MagicMock(),
            backend=MagicMock(),
            topic="test-topic",
            subscriber="test-sub",
            schema=MagicMock(),
            handler=AsyncMock(),
            rate_limit_retry_time=0.01,
            rate_limit_timeout=1,
        )
        assert consumer.concurrency == 1

    def test_custom_concurrency(self):
        """An explicit ``concurrency`` value is kept on the Consumer."""
        consumer = _make_consumer(concurrency=10)
        assert consumer.concurrency == 10

    def test_concurrency_stored(self):
        """A range of ``concurrency`` values are each stored unchanged."""
        for n in (1, 5, 20, 100):
            consumer = _make_consumer(concurrency=n)
            # Name the offending value so a failure mid-loop is easy to place
            assert consumer.concurrency == n, f"concurrency={n} not stored"
class TestTaskGroupConcurrency:
    """Checks how many consume_from_queue tasks consumer_run starts."""

    # Upper bound, in seconds, on a single consumer_run call. Without it
    # a consumer_run that never returns would hang the test session
    # instead of producing a failure.
    RUN_TIMEOUT = 5

    async def _count_consume_tasks(self, concurrency, settle_time):
        """Run consumer_run and return how many consume tasks were started.

        ``consume_from_queue`` is replaced with a stub that counts its
        invocations, sleeps for ``settle_time`` seconds so the other
        tasks get a chance to start, and then clears ``running`` to
        signal stop. ``backend.create_consumer`` is mocked as well.

        Raises ``asyncio.TimeoutError`` if consumer_run has not returned
        within ``RUN_TIMEOUT`` seconds.
        """
        consumer = _make_consumer(concurrency=concurrency)
        started = 0

        async def mock_consume(backend_consumer, executor=None):
            nonlocal started
            started += 1
            # Wait a bit to let all tasks start, then signal stop
            await asyncio.sleep(settle_time)
            consumer.running = False

        consumer.consume_from_queue = mock_consume
        consumer.backend.create_consumer = MagicMock(return_value=MagicMock())
        consumer.running = True

        await asyncio.wait_for(
            consumer.consumer_run(),
            timeout=self.RUN_TIMEOUT,
        )
        return started

    @pytest.mark.asyncio
    async def test_creates_n_concurrent_tasks(self):
        """consumer_run should create exactly N concurrent consume_from_queue tasks."""
        concurrency = 5
        started = await self._count_consume_tasks(concurrency, 0.05)
        assert started == concurrency

    @pytest.mark.asyncio
    async def test_single_concurrency_creates_one_task(self):
        """With concurrency=1, only one consume_from_queue task is created."""
        started = await self._count_consume_tasks(1, 0.01)
        assert started == 1
# ---------------------------------------------------------------------------
# Rate-limit retry tests
# ---------------------------------------------------------------------------
class TestRateLimitRetry:
    """Retry, acknowledge and negative-acknowledge checks for handle_one_from_queue."""

    @pytest.mark.asyncio
    async def test_rate_limit_retries_then_succeeds(self):
        """One TooManyRequests leads to a second attempt and an acknowledge."""
        attempts = 0

        async def flaky_handler(msg, consumer_ref, flow):
            nonlocal attempts
            attempts += 1
            # Only the very first attempt is rate limited
            if attempts == 1:
                raise TooManyRequests("rate limited")

        consumer = _make_consumer(
            handler=flaky_handler,
            rate_limit_retry_time=0.01,
        )
        backend_consumer = consumer.consumer = MagicMock()
        message = _make_msg()

        await consumer.handle_one_from_queue(message, backend_consumer)

        assert attempts == 2
        backend_consumer.acknowledge.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_rate_limit_timeout_negative_acks(self):
        """A handler that is rate limited past the timeout gets its message nacked."""

        async def always_rate_limited(msg, consumer_ref, flow):
            raise TooManyRequests("rate limited")

        consumer = _make_consumer(
            handler=always_rate_limited,
            rate_limit_retry_time=0.01,
            rate_limit_timeout=0.05,
        )
        backend_consumer = consumer.consumer = MagicMock()
        message = _make_msg()

        await consumer.handle_one_from_queue(message, backend_consumer)

        backend_consumer.negative_acknowledge.assert_called_with(message)
        backend_consumer.acknowledge.assert_not_called()

    @pytest.mark.asyncio
    async def test_non_rate_limit_error_negative_acks_immediately(self):
        """Errors other than TooManyRequests are nacked after a single attempt."""
        attempts = 0

        async def failing_handler(msg, consumer_ref, flow):
            nonlocal attempts
            attempts += 1
            raise ValueError("bad data")

        consumer = _make_consumer(handler=failing_handler)
        backend_consumer = consumer.consumer = MagicMock()
        message = _make_msg()

        await consumer.handle_one_from_queue(message, backend_consumer)

        assert attempts == 1
        backend_consumer.negative_acknowledge.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_successful_message_acknowledged(self):
        """A message whose handler completes normally is acknowledged once."""
        consumer = _make_consumer(handler=AsyncMock())
        backend_consumer = consumer.consumer = MagicMock()
        message = _make_msg()

        await consumer.handle_one_from_queue(message, backend_consumer)

        backend_consumer.acknowledge.assert_called_once_with(message)
# ---------------------------------------------------------------------------
# Metrics integration
# ---------------------------------------------------------------------------
class TestMetricsIntegration:
    """Checks which metrics calls handle_one_from_queue makes."""

    @staticmethod
    def _make_metrics():
        """Return a mock metrics object whose record_time() is a context manager.

        ``__exit__`` is pinned to return ``False``. A bare ``MagicMock()``
        installed as ``__exit__`` returns another mock, which is truthy
        and therefore tells the ``with`` statement to swallow any
        exception raised inside the block. That would hide a failing
        handler from the code under test.
        """
        metrics = MagicMock()
        timer = metrics.record_time.return_value
        timer.__enter__ = MagicMock()
        timer.__exit__ = MagicMock(return_value=False)
        return metrics

    @pytest.mark.asyncio
    async def test_success_metric_on_success(self):
        """A handler that completes normally records process("success")."""
        consumer = _make_consumer(handler=AsyncMock())
        mock_msg = _make_msg()
        consumer.consumer = MagicMock()
        mock_metrics = self._make_metrics()
        consumer.metrics = mock_metrics

        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)

        mock_metrics.process.assert_called_once_with("success")

    @pytest.mark.asyncio
    async def test_error_metric_on_failure(self):
        """A handler that raises ValueError records process("error")."""

        async def failing(msg, c, f):
            raise ValueError("fail")

        consumer = _make_consumer(handler=failing)
        mock_msg = _make_msg()
        consumer.consumer = MagicMock()
        mock_metrics = self._make_metrics()
        consumer.metrics = mock_metrics

        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)

        mock_metrics.process.assert_called_once_with("error")

    @pytest.mark.asyncio
    async def test_rate_limit_metric_on_too_many_requests(self):
        """A single TooManyRequests records exactly one rate_limit() call."""
        call_count = 0

        async def handler(msg, c, f):
            nonlocal call_count
            call_count += 1
            # Only the first attempt is rate limited; the retry succeeds
            if call_count == 1:
                raise TooManyRequests("limited")

        consumer = _make_consumer(
            handler=handler,
            rate_limit_retry_time=0.01,
        )
        mock_msg = _make_msg()
        consumer.consumer = MagicMock()
        mock_metrics = self._make_metrics()
        consumer.metrics = mock_metrics

        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)

        mock_metrics.rate_limit.assert_called_once()
# ---------------------------------------------------------------------------
# Poll timeout
# ---------------------------------------------------------------------------
class TestPollTimeout:
    """Checks the timeout the consumer passes to the backend's receive()."""

    @pytest.mark.asyncio
    async def test_poll_timeout_is_100ms(self):
        """The receive timeout must be 100ms rather than the earlier 2000ms.

        With a 2000ms poll timeout each service can sit blocked for up to
        2s between bursts of messages, and a query pipeline with many
        sequential hops accumulates that into seconds of avoidable
        latency. 100ms keeps things responsive without significant CPU
        overhead.
        """
        consumer = _make_consumer()
        seen_kwargs = {}
        timeout_error = type('Timeout', (Exception,), {})

        def capture_receive(**kwargs):
            # Remember what receive() was given, then end the loop after
            # this first call
            seen_kwargs.update(kwargs)
            consumer.running = False
            raise timeout_error("timeout")

        # Mock backend consumer whose receive() records its kwargs
        backend_consumer = MagicMock()
        backend_consumer.receive = capture_receive

        await consumer.consume_from_queue(backend_consumer)

        assert seen_kwargs.get("timeout_millis") == 100
# ---------------------------------------------------------------------------
# Stop / running flag
# ---------------------------------------------------------------------------
class TestStopBehaviour:
    """Checks on the Consumer's ``running`` flag."""

    @pytest.mark.asyncio
    async def test_stop_sets_running_false(self):
        """Awaiting stop() leaves ``running`` set to False."""
        subject = _make_consumer()
        subject.running = True

        await subject.stop()

        assert subject.running is False

    def test_initial_running_state(self):
        """A newly built consumer has ``running`` set to True."""
        subject = _make_consumer()

        assert subject.running is True