diff --git a/tests/unit/test_gateway/test_dispatch_manager.py b/tests/unit/test_gateway/test_dispatch_manager.py index 33f1229d..83969fdd 100644 --- a/tests/unit/test_gateway/test_dispatch_manager.py +++ b/tests/unit/test_gateway/test_dispatch_manager.py @@ -49,7 +49,8 @@ class TestDispatcherManager: assert manager.prefix == "api-gateway" # default prefix assert manager.flows == {} assert manager.dispatchers == {} - + assert isinstance(manager.dispatcher_lock, asyncio.Lock) + # Verify manager was added as handler to config receiver mock_config_receiver.add_handler.assert_called_once_with(manager) @@ -543,18 +544,99 @@ class TestDispatcherManager: mock_backend = Mock() mock_config_receiver = Mock() manager = DispatcherManager(mock_backend, mock_config_receiver) - + # Setup test flow with interface but unsupported kind manager.flows["test_flow"] = { "interfaces": { "invalid-kind": {"request": "req", "response": "resp"} } } - + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \ patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers: mock_rr_dispatchers.__contains__.return_value = False mock_sender_dispatchers.__contains__.return_value = False - + with pytest.raises(RuntimeError, match="Invalid kind"): - await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") \ No newline at end of file + await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") + + @pytest.mark.asyncio + async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self): + """Concurrent calls for the same service must create exactly one dispatcher. + + Before the fix, await dispatcher.start() yielded to the event loop and + multiple coroutines could all pass the 'key not in self.dispatchers' check + before any of them wrote the result back, creating duplicate Pulsar consumers. + """ + mock_backend = Mock() + mock_config_receiver = Mock() + manager = DispatcherManager(mock_backend, mock_config_receiver) + + async def slow_start(): + # Yield to the event loop so other coroutines get a chance to run, + # reproducing the window that caused the original race condition. + await asyncio.sleep(0) + + with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers: + mock_dispatcher_class = Mock() + mock_dispatcher = Mock() + mock_dispatcher.start = AsyncMock(side_effect=slow_start) + mock_dispatcher.process = AsyncMock(return_value="result") + mock_dispatcher_class.return_value = mock_dispatcher + mock_dispatchers.__getitem__.return_value = mock_dispatcher_class + + results = await asyncio.gather(*[ + manager.invoke_global_service("data", "responder", "config") + for _ in range(5) + ]) + + assert mock_dispatcher_class.call_count == 1, ( + "Dispatcher class instantiated more than once — duplicate consumer bug" + ) + assert mock_dispatcher.start.call_count == 1 + assert manager.dispatchers[(None, "config")] is mock_dispatcher + assert all(r == "result" for r in results) + + @pytest.mark.asyncio + async def test_invoke_flow_service_concurrent_calls_create_single_dispatcher(self): + """Concurrent calls for the same flow+kind must create exactly one dispatcher. + + invoke_flow_service has the same check-then-create pattern as + invoke_global_service and is protected by the same dispatcher_lock. + """ + mock_backend = Mock() + mock_config_receiver = Mock() + manager = DispatcherManager(mock_backend, mock_config_receiver) + + manager.flows["test_flow"] = { + "interfaces": { + "agent": { + "request": "agent_request_queue", + "response": "agent_response_queue", + } + } + } + + async def slow_start(): + await asyncio.sleep(0) + + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers: + mock_dispatcher_class = Mock() + mock_dispatcher = Mock() + mock_dispatcher.start = AsyncMock(side_effect=slow_start) + mock_dispatcher.process = AsyncMock(return_value="result") + mock_dispatcher_class.return_value = mock_dispatcher + mock_rr_dispatchers.__getitem__.return_value = mock_dispatcher_class + mock_rr_dispatchers.__contains__.return_value = True + + results = await asyncio.gather(*[ + manager.invoke_flow_service("data", "responder", "test_flow", "agent") + for _ in range(5) + ]) + + assert mock_dispatcher_class.call_count == 1, ( + "Dispatcher class instantiated more than once — duplicate consumer bug" + ) + assert mock_dispatcher.start.call_count == 1 + assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher + assert all(r == "result" for r in results) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index a4bf8de9..ef3d5507 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -118,6 +118,7 @@ class DispatcherManager: self.flows = {} self.dispatchers = {} + self.dispatcher_lock = asyncio.Lock() async def start_flow(self, id, flow): logger.info(f"Starting flow {id}") @@ -165,30 +166,28 @@ class DispatcherManager: key = (None, kind) - if key in self.dispatchers: - return await self.dispatchers[key].process(data, responder) + if key not in self.dispatchers: + async with self.dispatcher_lock: + if key not in self.dispatchers: + request_queue = None + response_queue = None + if kind in self.queue_overrides: + request_queue = self.queue_overrides[kind].get("request") + response_queue = self.queue_overrides[kind].get("response") - # Get queue overrides if specified for this service - request_queue = None - response_queue = None - if kind in self.queue_overrides: - request_queue = self.queue_overrides[kind].get("request") - response_queue = self.queue_overrides[kind].get("response") + dispatcher = global_dispatchers[kind]( + backend = self.backend, + timeout = 120, + consumer = f"{self.prefix}-{kind}-request", + subscriber = f"{self.prefix}-{kind}-request", + request_queue = request_queue, + response_queue = response_queue, + ) - dispatcher = global_dispatchers[kind]( - backend = self.backend, - timeout = 120, - consumer = f"{self.prefix}-{kind}-request", - subscriber = f"{self.prefix}-{kind}-request", - request_queue = request_queue, - response_queue = response_queue, - ) + await dispatcher.start() + self.dispatchers[key] = dispatcher - await dispatcher.start() - - self.dispatchers[key] = dispatcher - - return await dispatcher.process(data, responder) + return await self.dispatchers[key].process(data, responder) def dispatch_flow_import(self): return self.process_flow_import @@ -299,36 +298,35 @@ class DispatcherManager: key = (flow, kind) - if key in self.dispatchers: - return await self.dispatchers[key].process(data, responder) + if key not in self.dispatchers: + async with self.dispatcher_lock: + if key not in self.dispatchers: + intf_defs = self.flows[flow]["interfaces"] - intf_defs = self.flows[flow]["interfaces"] + if kind not in intf_defs: + raise RuntimeError("This kind not supported by flow") - if kind not in intf_defs: - raise RuntimeError("This kind not supported by flow") + qconfig = intf_defs[kind] - qconfig = intf_defs[kind] + if kind in request_response_dispatchers: + dispatcher = request_response_dispatchers[kind]( + backend = self.backend, + request_queue = qconfig["request"], + response_queue = qconfig["response"], + timeout = 120, + consumer = f"{self.prefix}-{flow}-{kind}-request", + subscriber = f"{self.prefix}-{flow}-{kind}-request", + ) + elif kind in sender_dispatchers: + dispatcher = sender_dispatchers[kind]( + backend = self.backend, + queue = qconfig, + ) + else: + raise RuntimeError("Invalid kind") - if kind in request_response_dispatchers: - dispatcher = request_response_dispatchers[kind]( - backend = self.backend, - request_queue = qconfig["request"], - response_queue = qconfig["response"], - timeout = 120, - consumer = f"{self.prefix}-{flow}-{kind}-request", - subscriber = f"{self.prefix}-{flow}-{kind}-request", - ) - elif kind in sender_dispatchers: - dispatcher = sender_dispatchers[kind]( - backend = self.backend, - queue = qconfig, - ) - else: - raise RuntimeError("Invalid kind") - - await dispatcher.start() + await dispatcher.start() + self.dispatchers[key] = dispatcher - self.dispatchers[key] = dispatcher - - return await dispatcher.process(data, responder) + return await self.dispatchers[key].process(data, responder)