Merge remote-tracking branch 'origin/master' into ts-port

This commit is contained in:
elpresidank 2026-04-07 10:51:24 -05:00
commit 2a2e8e76a3
7 changed files with 229 additions and 57 deletions

View file

@ -266,6 +266,41 @@ class TestMetricsIntegration:
mock_metrics.rate_limit.assert_called_once() mock_metrics.rate_limit.assert_called_once()
# ---------------------------------------------------------------------------
# Poll timeout
# ---------------------------------------------------------------------------
class TestPollTimeout:
@pytest.mark.asyncio
async def test_poll_timeout_is_100ms(self):
"""Consumer receive timeout should be 100ms, not the original 2000ms.
A 2000ms poll timeout means every service adds up to 2s of idle
blocking between message bursts. With many sequential hops in a
query pipeline, this compounds into seconds of unnecessary latency.
100ms keeps responsiveness high without significant CPU overhead.
"""
consumer = _make_consumer()
# Wire up a mock Pulsar consumer that records the receive kwargs
mock_pulsar_consumer = MagicMock()
received_kwargs = {}
def capture_receive(**kwargs):
received_kwargs.update(kwargs)
# Stop after one call
consumer.running = False
raise type('Timeout', (Exception,), {})("timeout")
mock_pulsar_consumer.receive = capture_receive
consumer.consumer = mock_pulsar_consumer
await consumer.consume_from_queue()
assert received_kwargs.get("timeout_millis") == 100
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Stop / running flag # Stop / running flag
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -49,6 +49,7 @@ class TestDispatcherManager:
assert manager.prefix == "api-gateway" # default prefix assert manager.prefix == "api-gateway" # default prefix
assert manager.flows == {} assert manager.flows == {}
assert manager.dispatchers == {} assert manager.dispatchers == {}
assert isinstance(manager.dispatcher_lock, asyncio.Lock)
# Verify manager was added as handler to config receiver # Verify manager was added as handler to config receiver
mock_config_receiver.add_handler.assert_called_once_with(manager) mock_config_receiver.add_handler.assert_called_once_with(manager)
@ -558,3 +559,84 @@ class TestDispatcherManager:
with pytest.raises(RuntimeError, match="Invalid kind"): with pytest.raises(RuntimeError, match="Invalid kind"):
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
@pytest.mark.asyncio
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
"""Concurrent calls for the same service must create exactly one dispatcher.
Before the fix, await dispatcher.start() yielded to the event loop and
multiple coroutines could all pass the 'key not in self.dispatchers' check
before any of them wrote the result back, creating duplicate Pulsar consumers.
"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
async def slow_start():
# Yield to the event loop so other coroutines get a chance to run,
# reproducing the window that caused the original race condition.
await asyncio.sleep(0)
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
mock_dispatcher.start = AsyncMock(side_effect=slow_start)
mock_dispatcher.process = AsyncMock(return_value="result")
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
results = await asyncio.gather(*[
manager.invoke_global_service("data", "responder", "config")
for _ in range(5)
])
assert mock_dispatcher_class.call_count == 1, (
"Dispatcher class instantiated more than once — duplicate consumer bug"
)
assert mock_dispatcher.start.call_count == 1
assert manager.dispatchers[(None, "config")] is mock_dispatcher
assert all(r == "result" for r in results)
@pytest.mark.asyncio
async def test_invoke_flow_service_concurrent_calls_create_single_dispatcher(self):
"""Concurrent calls for the same flow+kind must create exactly one dispatcher.
invoke_flow_service has the same check-then-create pattern as
invoke_global_service and is protected by the same dispatcher_lock.
"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager.flows["test_flow"] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
"response": "agent_response_queue",
}
}
}
async def slow_start():
await asyncio.sleep(0)
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers:
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
mock_dispatcher.start = AsyncMock(side_effect=slow_start)
mock_dispatcher.process = AsyncMock(return_value="result")
mock_dispatcher_class.return_value = mock_dispatcher
mock_rr_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_rr_dispatchers.__contains__.return_value = True
results = await asyncio.gather(*[
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
for _ in range(5)
])
assert mock_dispatcher_class.call_count == 1, (
"Dispatcher class instantiated more than once — duplicate consumer bug"
)
assert mock_dispatcher.start.call_count == 1
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
assert all(r == "result" for r in results)

View file

@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Assert # Assert
assert processor.default_model == 'gemma2:9b' # default_model assert processor.default_model == 'granite4:350m' # default_model
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env) # Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
mock_client_class.assert_called_once() mock_client_class.assert_called_once()

View file

@ -160,7 +160,7 @@ class Consumer:
try: try:
msg = await asyncio.to_thread( msg = await asyncio.to_thread(
self.consumer.receive, self.consumer.receive,
timeout_millis=2000 timeout_millis=100
) )
except Exception as e: except Exception as e:
# Handle timeout from any backend # Handle timeout from any backend

View file

@ -7,6 +7,9 @@ from ... base import EmbeddingsService
from ollama import Client from ollama import Client
import os import os
import logging
logger = logging.getLogger(__name__)
default_ident = "embeddings" default_ident = "embeddings"
@ -29,6 +32,28 @@ class Processor(EmbeddingsService):
self.client = Client(host=ollama) self.client = Client(host=ollama)
self.default_model = model self.default_model = model
self._checked_models = set()
def _ensure_model(self, model_name):
"""Check if model exists locally, pull it if not."""
if model_name in self._checked_models:
return
try:
self.client.show(model_name)
self._checked_models.add(model_name)
except Exception as e:
status_code = getattr(e, 'status_code', None)
if status_code == 404 or "not found" in str(e).lower():
logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
try:
self.client.pull(model_name)
self._checked_models.add(model_name)
logger.info(f"Successfully pulled Ollama model '{model_name}'.")
except Exception as pull_e:
logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}")
else:
logger.warning(f"Failed to check Ollama model '{model_name}': {e}")
async def on_embeddings(self, texts, model=None): async def on_embeddings(self, texts, model=None):
@ -37,6 +62,9 @@ class Processor(EmbeddingsService):
use_model = model or self.default_model use_model = model or self.default_model
# Ensure the model exists/is pulled
self._ensure_model(use_model)
# Ollama handles batch input efficiently # Ollama handles batch input efficiently
embeds = self.client.embed( embeds = self.client.embed(
model = use_model, model = use_model,

View file

@ -116,6 +116,7 @@ class DispatcherManager:
self.flows = {} self.flows = {}
self.dispatchers = {} self.dispatchers = {}
self.dispatcher_lock = asyncio.Lock()
async def start_flow(self, id, flow): async def start_flow(self, id, flow):
logger.info(f"Starting flow {id}") logger.info(f"Starting flow {id}")
@ -163,30 +164,28 @@ class DispatcherManager:
key = (None, kind) key = (None, kind)
if key in self.dispatchers: if key not in self.dispatchers:
return await self.dispatchers[key].process(data, responder) async with self.dispatcher_lock:
if key not in self.dispatchers:
request_queue = None
response_queue = None
if kind in self.queue_overrides:
request_queue = self.queue_overrides[kind].get("request")
response_queue = self.queue_overrides[kind].get("response")
# Get queue overrides if specified for this service dispatcher = global_dispatchers[kind](
request_queue = None backend = self.backend,
response_queue = None timeout = 120,
if kind in self.queue_overrides: consumer = f"{self.prefix}-{kind}-request",
request_queue = self.queue_overrides[kind].get("request") subscriber = f"{self.prefix}-{kind}-request",
response_queue = self.queue_overrides[kind].get("response") request_queue = request_queue,
response_queue = response_queue,
)
dispatcher = global_dispatchers[kind]( await dispatcher.start()
backend = self.backend, self.dispatchers[key] = dispatcher
timeout = 120,
consumer = f"{self.prefix}-{kind}-request",
subscriber = f"{self.prefix}-{kind}-request",
request_queue = request_queue,
response_queue = response_queue,
)
await dispatcher.start() return await self.dispatchers[key].process(data, responder)
self.dispatchers[key] = dispatcher
return await dispatcher.process(data, responder)
def dispatch_flow_import(self): def dispatch_flow_import(self):
return self.process_flow_import return self.process_flow_import
@ -297,36 +296,35 @@ class DispatcherManager:
key = (flow, kind) key = (flow, kind)
if key in self.dispatchers: if key not in self.dispatchers:
return await self.dispatchers[key].process(data, responder) async with self.dispatcher_lock:
if key not in self.dispatchers:
intf_defs = self.flows[flow]["interfaces"]
intf_defs = self.flows[flow]["interfaces"] if kind not in intf_defs:
raise RuntimeError("This kind not supported by flow")
if kind not in intf_defs: qconfig = intf_defs[kind]
raise RuntimeError("This kind not supported by flow")
qconfig = intf_defs[kind] if kind in request_response_dispatchers:
dispatcher = request_response_dispatchers[kind](
backend = self.backend,
request_queue = qconfig["request"],
response_queue = qconfig["response"],
timeout = 120,
consumer = f"{self.prefix}-{flow}-{kind}-request",
subscriber = f"{self.prefix}-{flow}-{kind}-request",
)
elif kind in sender_dispatchers:
dispatcher = sender_dispatchers[kind](
backend = self.backend,
queue = qconfig,
)
else:
raise RuntimeError("Invalid kind")
if kind in request_response_dispatchers: await dispatcher.start()
dispatcher = request_response_dispatchers[kind]( self.dispatchers[key] = dispatcher
backend = self.backend,
request_queue = qconfig["request"],
response_queue = qconfig["response"],
timeout = 120,
consumer = f"{self.prefix}-{flow}-{kind}-request",
subscriber = f"{self.prefix}-{flow}-{kind}-request",
)
elif kind in sender_dispatchers:
dispatcher = sender_dispatchers[kind](
backend = self.backend,
queue = qconfig,
)
else:
raise RuntimeError("Invalid kind")
await dispatcher.start() return await self.dispatchers[key].process(data, responder)
self.dispatchers[key] = dispatcher
return await dispatcher.process(data, responder)

View file

@ -16,7 +16,7 @@ from .... base import LlmService, LlmResult, LlmChunk
default_ident = "text-completion" default_ident = "text-completion"
default_model = 'gemma2:9b' default_model = 'granite4:350m'
default_temperature = 0.0 default_temperature = 0.0
default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434') default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
@ -39,11 +39,36 @@ class Processor(LlmService):
self.default_model = model self.default_model = model
self.temperature = temperature self.temperature = temperature
self.llm = Client(host=ollama) self.llm = Client(host=ollama)
self._checked_models = set()
def _ensure_model(self, model_name):
"""Check if model exists locally, pull it if not."""
if model_name in self._checked_models:
return
try:
self.llm.show(model_name)
self._checked_models.add(model_name)
except Exception as e:
status_code = getattr(e, 'status_code', None)
if status_code == 404 or "not found" in str(e).lower():
logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
try:
self.llm.pull(model_name)
self._checked_models.add(model_name)
logger.info(f"Successfully pulled Ollama model '{model_name}'.")
except Exception as pull_e:
logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}")
else:
logger.warning(f"Failed to check Ollama model '{model_name}': {e}")
async def generate_content(self, system, prompt, model=None, temperature=None): async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default # Use provided model or fall back to default
model_name = model or self.default_model model_name = model or self.default_model
# Ensure the model exists/is pulled
self._ensure_model(model_name)
# Use provided temperature or fall back to default # Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature effective_temperature = temperature if temperature is not None else self.temperature
@ -86,6 +111,10 @@ class Processor(LlmService):
async def generate_content_stream(self, system, prompt, model=None, temperature=None): async def generate_content_stream(self, system, prompt, model=None, temperature=None):
"""Stream content generation from Ollama""" """Stream content generation from Ollama"""
model_name = model or self.default_model model_name = model or self.default_model
# Ensure the model exists/is pulled
self._ensure_model(model_name)
effective_temperature = temperature if temperature is not None else self.temperature effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model (streaming): {model_name}") logger.debug(f"Using model (streaming): {model_name}")
@ -142,7 +171,7 @@ class Processor(LlmService):
parser.add_argument( parser.add_argument(
'-m', '--model', '-m', '--model',
default="gemma2", default="granite4:350m",
help=f'LLM model (default: {default_model})' help=f'LLM model (default: {default_model})'
) )