trustgraph/tests/unit/test_concurrency/test_dispatcher_semaphore.py
cybermaggedon e57f4669e1
Update rev-gateway for IAM integration (#940)
service.py:
- Constructor takes **config (same pattern as api-gateway) instead
  of individual args
- Creates IamAuth and calls await self.auth.start() before the
  message loop
- Passes auth to both ConfigReceiver and MessageDispatcher
- Uses add_pubsub_args / add_logging_args instead of hand-rolled
  Pulsar args
- Passes timeout through

dispatcher.py:
- Accepts auth and timeout parameters
- Passes both to DispatcherManager — fixes the missing auth argument
  that would have crashed on startup

The remote end's requests now go through the same IAM authentication
path as api-gateway. Token validation, workspace resolution, and
permissions all work identically regardless of which direction
initiated the connection.

Fixed tests — the test now passes auth and timeout to MessageDispatcher
and verifies they're forwarded to DispatcherManager.

Update the rev-gateway dispatcher to align with IAM. A "token" parameter
must be passed with each message.

Fix the websocket relay to align with the rev-gateway changes; it now
conforms to the api-gateway protocol.
2026-05-19 21:45:43 +01:00

139 lines
4.4 KiB
Python

"""
Tests for MessageDispatcher semaphore-based concurrency enforcement.
Verifies that the dispatcher limits concurrent message processing to
max_workers via asyncio.Semaphore.
"""
import asyncio
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
class TestSemaphoreEnforcement:
    """Concurrency-limit tests for ``MessageDispatcher.handle_message``.

    Every test replaces the dispatcher's ``_process_message`` coroutine with
    a local stub, so only the semaphore / active-task bookkeeping around it
    is exercised.
    """

    @pytest.mark.asyncio
    async def test_semaphore_limits_concurrent_processing(self):
        """Exactly max_workers messages should be processed concurrently.

        Five messages are dispatched against a limit of two. The stub records
        the peak number of overlapping calls. That peak must reach the limit,
        proving work really does overlap, and must never exceed it.
        """
        max_workers = 2
        dispatcher = MessageDispatcher(max_workers=max_workers)
        concurrent_count = 0
        max_concurrent = 0

        async def slow_process(message, sender):
            nonlocal concurrent_count, max_concurrent
            concurrent_count += 1
            max_concurrent = max(max_concurrent, concurrent_count)
            # Hold the slot long enough for the queued tasks to try to enter.
            await asyncio.sleep(0.05)
            concurrent_count -= 1

        dispatcher._process_message = slow_process
        sender = AsyncMock()

        # Launch more tasks than max_workers
        messages = [
            {"id": f"msg-{i}", "service": "test", "request": {}}
            for i in range(5)
        ]
        tasks = [
            asyncio.create_task(dispatcher.handle_message(m, sender))
            for m in messages
        ]
        await asyncio.gather(*tasks)

        # The limit must be reached (work overlaps) but never exceeded.
        assert max_concurrent == max_workers

    @pytest.mark.asyncio
    async def test_semaphore_value_matches_max_workers(self):
        """The dispatcher's semaphore should start with max_workers permits."""
        for n in [1, 5, 20]:
            dispatcher = MessageDispatcher(max_workers=n)
            # asyncio.Semaphore exposes no public getter for its counter,
            # so the private _value attribute is inspected directly.
            assert dispatcher.semaphore._value == n

    @pytest.mark.asyncio
    async def test_active_tasks_tracked(self):
        """Active tasks should be added/removed during processing."""
        dispatcher = MessageDispatcher(max_workers=5)
        task_was_tracked = False

        async def tracking_process(message, sender):
            nonlocal task_was_tracked
            # While the stub runs, the dispatcher should be tracking at
            # least one active task.
            if dispatcher.active_tasks:
                task_was_tracked = True

        dispatcher._process_message = tracking_process
        await dispatcher.handle_message(
            {"id": "test", "service": "test", "request": {}},
            AsyncMock(),
        )
        assert task_was_tracked
        # After completion, the task should have been discarded
        assert len(dispatcher.active_tasks) == 0

    @pytest.mark.asyncio
    async def test_semaphore_released_on_error(self):
        """Semaphore should be released even if processing raises."""
        dispatcher = MessageDispatcher(max_workers=2)

        async def failing_process(message, sender):
            raise RuntimeError("process failed")

        dispatcher._process_message = failing_process

        # Should not deadlock; the semaphore must be released on error.
        # `match` ensures it is the stub's error that propagated, not some
        # unrelated RuntimeError raised inside the dispatcher.
        with pytest.raises(RuntimeError, match="process failed"):
            await dispatcher.handle_message(
                {"id": "test", "service": "test", "request": {}},
                AsyncMock(),
            )

        # Semaphore should be back at max
        assert dispatcher.semaphore._value == 2

    @pytest.mark.asyncio
    async def test_single_worker_serializes_processing(self):
        """With max_workers=1, messages are processed one at a time."""
        dispatcher = MessageDispatcher(max_workers=1)
        order = []

        async def ordered_process(message, sender):
            msg_id = message["id"]
            order.append(f"start-{msg_id}")
            await asyncio.sleep(0.02)
            order.append(f"end-{msg_id}")

        dispatcher._process_message = ordered_process
        sender = AsyncMock()
        messages = [
            {"id": str(i), "service": "t", "request": {}} for i in range(3)
        ]
        tasks = [
            asyncio.create_task(dispatcher.handle_message(m, sender))
            for m in messages
        ]
        await asyncio.gather(*tasks)

        # Every message must have been processed (one start + one end each).
        assert len(order) == 2 * len(messages)
        # With semaphore=1, each message must finish before the next starts,
        # so the log must alternate start-X, end-X for the same X.
        for start, end in zip(order[::2], order[1::2]):
            assert start.startswith("start-")
            assert end == "end-" + start.split("-", 1)[1]