From 7daa06e9e4d18a78aaa5b0f50ea0167bd9c79243 Mon Sep 17 00:00:00 2001 From: Alex Jenkins Date: Mon, 6 Apr 2026 10:10:14 +0000 Subject: [PATCH 1/3] Feat: Auto-pull missing Ollama models (#757) * fix deadlink in readme Signed-off-by: Jenkins, Kenneth Alexander * feat: Auto-pull Ollama models Signed-off-by: Jenkins, Kenneth Alexander * fix: Restore namespace __init__.py files for package resolution Signed-off-by: Jenkins, Kenneth Alexander * fix CI Signed-off-by: Jenkins, Kenneth Alexander --- .../test_ollama_processor.py | 2 +- .../trustgraph/embeddings/ollama/processor.py | 28 ++++++++++++++++ .../model/text_completion/ollama/llm.py | 33 +++++++++++++++++-- 3 files changed, 60 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_text_completion/test_ollama_processor.py b/tests/unit/test_text_completion/test_ollama_processor.py index 0bf5e0ab..69baf85f 100644 --- a/tests/unit/test_text_completion/test_ollama_processor.py +++ b/tests/unit/test_text_completion/test_ollama_processor.py @@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): processor = Processor(**config) # Assert - assert processor.default_model == 'gemma2:9b' # default_model + assert processor.default_model == 'granite4:350m' # default_model # Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env) mock_client_class.assert_called_once() diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index a65b4ff7..c63db33c 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -7,6 +7,9 @@ from ... base import EmbeddingsService from ollama import Client import os +import logging + +logger = logging.getLogger(__name__) default_ident = "embeddings" @@ -29,6 +32,28 @@ class Processor(EmbeddingsService): self.client = Client(host=ollama) self.default_model = model + self._checked_models = set() + + def _ensure_model(self, model_name): + """Check if model exists locally, pull it if not.""" + if model_name in self._checked_models: + return + + try: + self.client.show(model_name) + self._checked_models.add(model_name) + except Exception as e: + status_code = getattr(e, 'status_code', None) + if status_code == 404 or "not found" in str(e).lower(): + logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...") + try: + self.client.pull(model_name) + self._checked_models.add(model_name) + logger.info(f"Successfully pulled Ollama model '{model_name}'.") + except Exception as pull_e: + logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}") + else: + logger.warning(f"Failed to check Ollama model '{model_name}': {e}") async def on_embeddings(self, texts, model=None): @@ -37,6 +62,9 @@ class Processor(EmbeddingsService): use_model = model or self.default_model + # Ensure the model exists/is pulled + self._ensure_model(use_model) + # Ollama handles batch input efficiently embeds = self.client.embed( model = use_model, diff --git a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py index 3616e428..f6c5dcb8 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py @@ -16,7 +16,7 @@ from .... base import LlmService, LlmResult, LlmChunk default_ident = "text-completion" -default_model = 'gemma2:9b' +default_model = 'granite4:350m' default_temperature = 0.0 default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434') @@ -39,11 +39,36 @@ class Processor(LlmService): self.default_model = model self.temperature = temperature self.llm = Client(host=ollama) + self._checked_models = set() + + def _ensure_model(self, model_name): + """Check if model exists locally, pull it if not.""" + if model_name in self._checked_models: + return + + try: + self.llm.show(model_name) + self._checked_models.add(model_name) + except Exception as e: + status_code = getattr(e, 'status_code', None) + if status_code == 404 or "not found" in str(e).lower(): + logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...") + try: + self.llm.pull(model_name) + self._checked_models.add(model_name) + logger.info(f"Successfully pulled Ollama model '{model_name}'.") + except Exception as pull_e: + logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}") + else: + logger.warning(f"Failed to check Ollama model '{model_name}': {e}") async def generate_content(self, system, prompt, model=None, temperature=None): # Use provided model or fall back to default model_name = model or self.default_model + + # Ensure the model exists/is pulled + self._ensure_model(model_name) # Use provided temperature or fall back to default effective_temperature = temperature if temperature is not None else self.temperature @@ -86,6 +111,10 @@ class Processor(LlmService): async def generate_content_stream(self, system, prompt, model=None, temperature=None): """Stream content generation from Ollama""" model_name = model or self.default_model + + # Ensure the model exists/is pulled + self._ensure_model(model_name) + effective_temperature = temperature if temperature is not None else self.temperature logger.debug(f"Using model (streaming): {model_name}") @@ -142,7 +171,7 @@ class Processor(LlmService): parser.add_argument( '-m', '--model', - default="gemma2", + default="granite4:350m", help=f'LLM model (default: {default_model})' ) From 8f18ba025738e58b624371b685b684d8183c9c0c Mon Sep 17 00:00:00 2001 From: "V.Sreeram" Date: Mon, 6 Apr 2026 15:43:59 +0530 Subject: [PATCH 2/3] fix: prevent duplicate dispatcher creation race condition in invoke_global_service (#715) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: prevent duplicate dispatcher creation race condition in invoke_global_service Concurrent coroutines could all pass the `if key in self.dispatchers` check before any of them wrote the result back, because `await dispatcher.start()` yields to the event loop. This caused multiple Pulsar consumers to be created on the same shared subscription, distributing responses round-robin and dropping ~2/3 of them — manifesting as a permanent spinner in the Workbench UI. Apply a double-checked asyncio.Lock in both `invoke_global_service` and `invoke_flow_service` so only one dispatcher is ever created per service key. * test: add concurrent-dispatch tests for race condition fix Add asyncio.gather-based tests that verify invoke_global_service and invoke_flow_service create exactly one dispatcher under concurrent calls, preventing the duplicate Pulsar consumer bug. --- .../test_gateway/test_dispatch_manager.py | 92 +++++++++++++++++- .../trustgraph/gateway/dispatch/manager.py | 94 +++++++++---------- 2 files changed, 133 insertions(+), 53 deletions(-) diff --git a/tests/unit/test_gateway/test_dispatch_manager.py b/tests/unit/test_gateway/test_dispatch_manager.py index 33f1229d..83969fdd 100644 --- a/tests/unit/test_gateway/test_dispatch_manager.py +++ b/tests/unit/test_gateway/test_dispatch_manager.py @@ -49,7 +49,8 @@ class TestDispatcherManager: assert manager.prefix == "api-gateway" # default prefix assert manager.flows == {} assert manager.dispatchers == {} - + assert isinstance(manager.dispatcher_lock, asyncio.Lock) + # Verify manager was added as handler to config receiver mock_config_receiver.add_handler.assert_called_once_with(manager) @@ -543,18 +544,99 @@ class TestDispatcherManager: mock_backend = Mock() mock_config_receiver = Mock() manager = DispatcherManager(mock_backend, mock_config_receiver) - + # Setup test flow with interface but unsupported kind manager.flows["test_flow"] = { "interfaces": { "invalid-kind": {"request": "req", "response": "resp"} } } - + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \ patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers: mock_rr_dispatchers.__contains__.return_value = False mock_sender_dispatchers.__contains__.return_value = False - + with pytest.raises(RuntimeError, match="Invalid kind"): - await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") \ No newline at end of file + await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") + + @pytest.mark.asyncio + async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self): + """Concurrent calls for the same service must create exactly one dispatcher. + + Before the fix, await dispatcher.start() yielded to the event loop and + multiple coroutines could all pass the 'key not in self.dispatchers' check + before any of them wrote the result back, creating duplicate Pulsar consumers. + """ + mock_backend = Mock() + mock_config_receiver = Mock() + manager = DispatcherManager(mock_backend, mock_config_receiver) + + async def slow_start(): + # Yield to the event loop so other coroutines get a chance to run, + # reproducing the window that caused the original race condition. + await asyncio.sleep(0) + + with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers: + mock_dispatcher_class = Mock() + mock_dispatcher = Mock() + mock_dispatcher.start = AsyncMock(side_effect=slow_start) + mock_dispatcher.process = AsyncMock(return_value="result") + mock_dispatcher_class.return_value = mock_dispatcher + mock_dispatchers.__getitem__.return_value = mock_dispatcher_class + + results = await asyncio.gather(*[ + manager.invoke_global_service("data", "responder", "config") + for _ in range(5) + ]) + + assert mock_dispatcher_class.call_count == 1, ( + "Dispatcher class instantiated more than once — duplicate consumer bug" + ) + assert mock_dispatcher.start.call_count == 1 + assert manager.dispatchers[(None, "config")] is mock_dispatcher + assert all(r == "result" for r in results) + + @pytest.mark.asyncio + async def test_invoke_flow_service_concurrent_calls_create_single_dispatcher(self): + """Concurrent calls for the same flow+kind must create exactly one dispatcher. + + invoke_flow_service has the same check-then-create pattern as + invoke_global_service and is protected by the same dispatcher_lock. + """ + mock_backend = Mock() + mock_config_receiver = Mock() + manager = DispatcherManager(mock_backend, mock_config_receiver) + + manager.flows["test_flow"] = { + "interfaces": { + "agent": { + "request": "agent_request_queue", + "response": "agent_response_queue", + } + } + } + + async def slow_start(): + await asyncio.sleep(0) + + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers: + mock_dispatcher_class = Mock() + mock_dispatcher = Mock() + mock_dispatcher.start = AsyncMock(side_effect=slow_start) + mock_dispatcher.process = AsyncMock(return_value="result") + mock_dispatcher_class.return_value = mock_dispatcher + mock_rr_dispatchers.__getitem__.return_value = mock_dispatcher_class + mock_rr_dispatchers.__contains__.return_value = True + + results = await asyncio.gather(*[ + manager.invoke_flow_service("data", "responder", "test_flow", "agent") + for _ in range(5) + ]) + + assert mock_dispatcher_class.call_count == 1, ( + "Dispatcher class instantiated more than once — duplicate consumer bug" + ) + assert mock_dispatcher.start.call_count == 1 + assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher + assert all(r == "result" for r in results) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index d068ecef..3fdb3b12 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -116,6 +116,7 @@ class DispatcherManager: self.flows = {} self.dispatchers = {} + self.dispatcher_lock = asyncio.Lock() async def start_flow(self, id, flow): logger.info(f"Starting flow {id}") @@ -163,30 +164,28 @@ class DispatcherManager: key = (None, kind) - if key in self.dispatchers: - return await self.dispatchers[key].process(data, responder) + if key not in self.dispatchers: + async with self.dispatcher_lock: + if key not in self.dispatchers: + request_queue = None + response_queue = None + if kind in self.queue_overrides: + request_queue = self.queue_overrides[kind].get("request") + response_queue = self.queue_overrides[kind].get("response") - # Get queue overrides if specified for this service - request_queue = None - response_queue = None - if kind in self.queue_overrides: - request_queue = self.queue_overrides[kind].get("request") - response_queue = self.queue_overrides[kind].get("response") + dispatcher = global_dispatchers[kind]( + backend = self.backend, + timeout = 120, + consumer = f"{self.prefix}-{kind}-request", + subscriber = f"{self.prefix}-{kind}-request", + request_queue = request_queue, + response_queue = response_queue, + ) - dispatcher = global_dispatchers[kind]( - backend = self.backend, - timeout = 120, - consumer = f"{self.prefix}-{kind}-request", - subscriber = f"{self.prefix}-{kind}-request", - request_queue = request_queue, - response_queue = response_queue, - ) + await dispatcher.start() + self.dispatchers[key] = dispatcher - await dispatcher.start() - - self.dispatchers[key] = dispatcher - - return await dispatcher.process(data, responder) + return await self.dispatchers[key].process(data, responder) def dispatch_flow_import(self): return self.process_flow_import @@ -297,36 +296,35 @@ class DispatcherManager: key = (flow, kind) - if key in self.dispatchers: - return await self.dispatchers[key].process(data, responder) + if key not in self.dispatchers: + async with self.dispatcher_lock: + if key not in self.dispatchers: + intf_defs = self.flows[flow]["interfaces"] - intf_defs = self.flows[flow]["interfaces"] + if kind not in intf_defs: + raise RuntimeError("This kind not supported by flow") - if kind not in intf_defs: - raise RuntimeError("This kind not supported by flow") + qconfig = intf_defs[kind] - qconfig = intf_defs[kind] + if kind in request_response_dispatchers: + dispatcher = request_response_dispatchers[kind]( + backend = self.backend, + request_queue = qconfig["request"], + response_queue = qconfig["response"], + timeout = 120, + consumer = f"{self.prefix}-{flow}-{kind}-request", + subscriber = f"{self.prefix}-{flow}-{kind}-request", + ) + elif kind in sender_dispatchers: + dispatcher = sender_dispatchers[kind]( + backend = self.backend, + queue = qconfig, + ) + else: + raise RuntimeError("Invalid kind") - if kind in request_response_dispatchers: - dispatcher = request_response_dispatchers[kind]( - backend = self.backend, - request_queue = qconfig["request"], - response_queue = qconfig["response"], - timeout = 120, - consumer = f"{self.prefix}-{flow}-{kind}-request", - subscriber = f"{self.prefix}-{flow}-{kind}-request", - ) - elif kind in sender_dispatchers: - dispatcher = sender_dispatchers[kind]( - backend = self.backend, - queue = qconfig, - ) - else: - raise RuntimeError("Invalid kind") - - await dispatcher.start() + await dispatcher.start() + self.dispatchers[key] = dispatcher - self.dispatchers[key] = dispatcher - - return await dispatcher.process(data, responder) + return await self.dispatchers[key].process(data, responder) From c737e8c356ea5f3463d47878daab20078a8d8aa2 Mon Sep 17 00:00:00 2001 From: Sreeram Venkatasubramanian Date: Tue, 7 Apr 2026 16:39:20 +0530 Subject: [PATCH 3/3] fix: reduce consumer poll timeout from 2000ms to 100ms (#761) --- .../test_consumer_concurrency.py | 35 +++++++++++++++++++ trustgraph-base/trustgraph/base/consumer.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py b/tests/unit/test_concurrency/test_consumer_concurrency.py index 32a6559b..d9dc7d5b 100644 --- a/tests/unit/test_concurrency/test_consumer_concurrency.py +++ b/tests/unit/test_concurrency/test_consumer_concurrency.py @@ -266,6 +266,41 @@ class TestMetricsIntegration: mock_metrics.rate_limit.assert_called_once() +# --------------------------------------------------------------------------- +# Poll timeout +# --------------------------------------------------------------------------- + +class TestPollTimeout: + + @pytest.mark.asyncio + async def test_poll_timeout_is_100ms(self): + """Consumer receive timeout should be 100ms, not the original 2000ms. + + A 2000ms poll timeout means every service adds up to 2s of idle + blocking between message bursts. With many sequential hops in a + query pipeline, this compounds into seconds of unnecessary latency. + 100ms keeps responsiveness high without significant CPU overhead. + """ + consumer = _make_consumer() + + # Wire up a mock Pulsar consumer that records the receive kwargs + mock_pulsar_consumer = MagicMock() + received_kwargs = {} + + def capture_receive(**kwargs): + received_kwargs.update(kwargs) + # Stop after one call + consumer.running = False + raise type('Timeout', (Exception,), {})("timeout") + + mock_pulsar_consumer.receive = capture_receive + consumer.consumer = mock_pulsar_consumer + + await consumer.consume_from_queue() + + assert received_kwargs.get("timeout_millis") == 100 + + # --------------------------------------------------------------------------- # Stop / running flag # --------------------------------------------------------------------------- diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 2a220312..d851f3c1 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -160,7 +160,7 @@ class Consumer: try: msg = await asyncio.to_thread( self.consumer.receive, - timeout_millis=2000 + timeout_millis=100 ) except Exception as e: # Handle timeout from any backend