mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 09:29:38 +02:00
Merge remote-tracking branch 'origin/master' into ts-port
This commit is contained in:
commit
2a2e8e76a3
7 changed files with 229 additions and 57 deletions
|
|
@ -266,6 +266,41 @@ class TestMetricsIntegration:
|
||||||
mock_metrics.rate_limit.assert_called_once()
|
mock_metrics.rate_limit.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Poll timeout
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPollTimeout:
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_timeout_is_100ms(self):
|
||||||
|
"""Consumer receive timeout should be 100ms, not the original 2000ms.
|
||||||
|
|
||||||
|
A 2000ms poll timeout means every service adds up to 2s of idle
|
||||||
|
blocking between message bursts. With many sequential hops in a
|
||||||
|
query pipeline, this compounds into seconds of unnecessary latency.
|
||||||
|
100ms keeps responsiveness high without significant CPU overhead.
|
||||||
|
"""
|
||||||
|
consumer = _make_consumer()
|
||||||
|
|
||||||
|
# Wire up a mock Pulsar consumer that records the receive kwargs
|
||||||
|
mock_pulsar_consumer = MagicMock()
|
||||||
|
received_kwargs = {}
|
||||||
|
|
||||||
|
def capture_receive(**kwargs):
|
||||||
|
received_kwargs.update(kwargs)
|
||||||
|
# Stop after one call
|
||||||
|
consumer.running = False
|
||||||
|
raise type('Timeout', (Exception,), {})("timeout")
|
||||||
|
|
||||||
|
mock_pulsar_consumer.receive = capture_receive
|
||||||
|
consumer.consumer = mock_pulsar_consumer
|
||||||
|
|
||||||
|
await consumer.consume_from_queue()
|
||||||
|
|
||||||
|
assert received_kwargs.get("timeout_millis") == 100
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Stop / running flag
|
# Stop / running flag
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ class TestDispatcherManager:
|
||||||
assert manager.prefix == "api-gateway" # default prefix
|
assert manager.prefix == "api-gateway" # default prefix
|
||||||
assert manager.flows == {}
|
assert manager.flows == {}
|
||||||
assert manager.dispatchers == {}
|
assert manager.dispatchers == {}
|
||||||
|
assert isinstance(manager.dispatcher_lock, asyncio.Lock)
|
||||||
|
|
||||||
# Verify manager was added as handler to config receiver
|
# Verify manager was added as handler to config receiver
|
||||||
mock_config_receiver.add_handler.assert_called_once_with(manager)
|
mock_config_receiver.add_handler.assert_called_once_with(manager)
|
||||||
|
|
@ -558,3 +559,84 @@ class TestDispatcherManager:
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="Invalid kind"):
|
with pytest.raises(RuntimeError, match="Invalid kind"):
|
||||||
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
|
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
|
||||||
|
"""Concurrent calls for the same service must create exactly one dispatcher.
|
||||||
|
|
||||||
|
Before the fix, await dispatcher.start() yielded to the event loop and
|
||||||
|
multiple coroutines could all pass the 'key not in self.dispatchers' check
|
||||||
|
before any of them wrote the result back, creating duplicate Pulsar consumers.
|
||||||
|
"""
|
||||||
|
mock_backend = Mock()
|
||||||
|
mock_config_receiver = Mock()
|
||||||
|
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||||
|
|
||||||
|
async def slow_start():
|
||||||
|
# Yield to the event loop so other coroutines get a chance to run,
|
||||||
|
# reproducing the window that caused the original race condition.
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
|
||||||
|
mock_dispatcher_class = Mock()
|
||||||
|
mock_dispatcher = Mock()
|
||||||
|
mock_dispatcher.start = AsyncMock(side_effect=slow_start)
|
||||||
|
mock_dispatcher.process = AsyncMock(return_value="result")
|
||||||
|
mock_dispatcher_class.return_value = mock_dispatcher
|
||||||
|
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||||
|
|
||||||
|
results = await asyncio.gather(*[
|
||||||
|
manager.invoke_global_service("data", "responder", "config")
|
||||||
|
for _ in range(5)
|
||||||
|
])
|
||||||
|
|
||||||
|
assert mock_dispatcher_class.call_count == 1, (
|
||||||
|
"Dispatcher class instantiated more than once — duplicate consumer bug"
|
||||||
|
)
|
||||||
|
assert mock_dispatcher.start.call_count == 1
|
||||||
|
assert manager.dispatchers[(None, "config")] is mock_dispatcher
|
||||||
|
assert all(r == "result" for r in results)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_flow_service_concurrent_calls_create_single_dispatcher(self):
|
||||||
|
"""Concurrent calls for the same flow+kind must create exactly one dispatcher.
|
||||||
|
|
||||||
|
invoke_flow_service has the same check-then-create pattern as
|
||||||
|
invoke_global_service and is protected by the same dispatcher_lock.
|
||||||
|
"""
|
||||||
|
mock_backend = Mock()
|
||||||
|
mock_config_receiver = Mock()
|
||||||
|
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||||
|
|
||||||
|
manager.flows["test_flow"] = {
|
||||||
|
"interfaces": {
|
||||||
|
"agent": {
|
||||||
|
"request": "agent_request_queue",
|
||||||
|
"response": "agent_response_queue",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def slow_start():
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers:
|
||||||
|
mock_dispatcher_class = Mock()
|
||||||
|
mock_dispatcher = Mock()
|
||||||
|
mock_dispatcher.start = AsyncMock(side_effect=slow_start)
|
||||||
|
mock_dispatcher.process = AsyncMock(return_value="result")
|
||||||
|
mock_dispatcher_class.return_value = mock_dispatcher
|
||||||
|
mock_rr_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||||
|
mock_rr_dispatchers.__contains__.return_value = True
|
||||||
|
|
||||||
|
results = await asyncio.gather(*[
|
||||||
|
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||||
|
for _ in range(5)
|
||||||
|
])
|
||||||
|
|
||||||
|
assert mock_dispatcher_class.call_count == 1, (
|
||||||
|
"Dispatcher class instantiated more than once — duplicate consumer bug"
|
||||||
|
)
|
||||||
|
assert mock_dispatcher.start.call_count == 1
|
||||||
|
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
|
||||||
|
assert all(r == "result" for r in results)
|
||||||
|
|
@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert processor.default_model == 'gemma2:9b' # default_model
|
assert processor.default_model == 'granite4:350m' # default_model
|
||||||
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
|
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
|
||||||
mock_client_class.assert_called_once()
|
mock_client_class.assert_called_once()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -160,7 +160,7 @@ class Consumer:
|
||||||
try:
|
try:
|
||||||
msg = await asyncio.to_thread(
|
msg = await asyncio.to_thread(
|
||||||
self.consumer.receive,
|
self.consumer.receive,
|
||||||
timeout_millis=2000
|
timeout_millis=100
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle timeout from any backend
|
# Handle timeout from any backend
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,9 @@ from ... base import EmbeddingsService
|
||||||
|
|
||||||
from ollama import Client
|
from ollama import Client
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "embeddings"
|
default_ident = "embeddings"
|
||||||
|
|
||||||
|
|
@ -29,6 +32,28 @@ class Processor(EmbeddingsService):
|
||||||
|
|
||||||
self.client = Client(host=ollama)
|
self.client = Client(host=ollama)
|
||||||
self.default_model = model
|
self.default_model = model
|
||||||
|
self._checked_models = set()
|
||||||
|
|
||||||
|
def _ensure_model(self, model_name):
|
||||||
|
"""Check if model exists locally, pull it if not."""
|
||||||
|
if model_name in self._checked_models:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.show(model_name)
|
||||||
|
self._checked_models.add(model_name)
|
||||||
|
except Exception as e:
|
||||||
|
status_code = getattr(e, 'status_code', None)
|
||||||
|
if status_code == 404 or "not found" in str(e).lower():
|
||||||
|
logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
|
||||||
|
try:
|
||||||
|
self.client.pull(model_name)
|
||||||
|
self._checked_models.add(model_name)
|
||||||
|
logger.info(f"Successfully pulled Ollama model '{model_name}'.")
|
||||||
|
except Exception as pull_e:
|
||||||
|
logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Failed to check Ollama model '{model_name}': {e}")
|
||||||
|
|
||||||
async def on_embeddings(self, texts, model=None):
|
async def on_embeddings(self, texts, model=None):
|
||||||
|
|
||||||
|
|
@ -37,6 +62,9 @@ class Processor(EmbeddingsService):
|
||||||
|
|
||||||
use_model = model or self.default_model
|
use_model = model or self.default_model
|
||||||
|
|
||||||
|
# Ensure the model exists/is pulled
|
||||||
|
self._ensure_model(use_model)
|
||||||
|
|
||||||
# Ollama handles batch input efficiently
|
# Ollama handles batch input efficiently
|
||||||
embeds = self.client.embed(
|
embeds = self.client.embed(
|
||||||
model = use_model,
|
model = use_model,
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,7 @@ class DispatcherManager:
|
||||||
|
|
||||||
self.flows = {}
|
self.flows = {}
|
||||||
self.dispatchers = {}
|
self.dispatchers = {}
|
||||||
|
self.dispatcher_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def start_flow(self, id, flow):
|
async def start_flow(self, id, flow):
|
||||||
logger.info(f"Starting flow {id}")
|
logger.info(f"Starting flow {id}")
|
||||||
|
|
@ -163,30 +164,28 @@ class DispatcherManager:
|
||||||
|
|
||||||
key = (None, kind)
|
key = (None, kind)
|
||||||
|
|
||||||
if key in self.dispatchers:
|
if key not in self.dispatchers:
|
||||||
return await self.dispatchers[key].process(data, responder)
|
async with self.dispatcher_lock:
|
||||||
|
if key not in self.dispatchers:
|
||||||
|
request_queue = None
|
||||||
|
response_queue = None
|
||||||
|
if kind in self.queue_overrides:
|
||||||
|
request_queue = self.queue_overrides[kind].get("request")
|
||||||
|
response_queue = self.queue_overrides[kind].get("response")
|
||||||
|
|
||||||
# Get queue overrides if specified for this service
|
dispatcher = global_dispatchers[kind](
|
||||||
request_queue = None
|
backend = self.backend,
|
||||||
response_queue = None
|
timeout = 120,
|
||||||
if kind in self.queue_overrides:
|
consumer = f"{self.prefix}-{kind}-request",
|
||||||
request_queue = self.queue_overrides[kind].get("request")
|
subscriber = f"{self.prefix}-{kind}-request",
|
||||||
response_queue = self.queue_overrides[kind].get("response")
|
request_queue = request_queue,
|
||||||
|
response_queue = response_queue,
|
||||||
|
)
|
||||||
|
|
||||||
dispatcher = global_dispatchers[kind](
|
await dispatcher.start()
|
||||||
backend = self.backend,
|
self.dispatchers[key] = dispatcher
|
||||||
timeout = 120,
|
|
||||||
consumer = f"{self.prefix}-{kind}-request",
|
|
||||||
subscriber = f"{self.prefix}-{kind}-request",
|
|
||||||
request_queue = request_queue,
|
|
||||||
response_queue = response_queue,
|
|
||||||
)
|
|
||||||
|
|
||||||
await dispatcher.start()
|
return await self.dispatchers[key].process(data, responder)
|
||||||
|
|
||||||
self.dispatchers[key] = dispatcher
|
|
||||||
|
|
||||||
return await dispatcher.process(data, responder)
|
|
||||||
|
|
||||||
def dispatch_flow_import(self):
|
def dispatch_flow_import(self):
|
||||||
return self.process_flow_import
|
return self.process_flow_import
|
||||||
|
|
@ -297,36 +296,35 @@ class DispatcherManager:
|
||||||
|
|
||||||
key = (flow, kind)
|
key = (flow, kind)
|
||||||
|
|
||||||
if key in self.dispatchers:
|
if key not in self.dispatchers:
|
||||||
return await self.dispatchers[key].process(data, responder)
|
async with self.dispatcher_lock:
|
||||||
|
if key not in self.dispatchers:
|
||||||
|
intf_defs = self.flows[flow]["interfaces"]
|
||||||
|
|
||||||
intf_defs = self.flows[flow]["interfaces"]
|
if kind not in intf_defs:
|
||||||
|
raise RuntimeError("This kind not supported by flow")
|
||||||
|
|
||||||
if kind not in intf_defs:
|
qconfig = intf_defs[kind]
|
||||||
raise RuntimeError("This kind not supported by flow")
|
|
||||||
|
|
||||||
qconfig = intf_defs[kind]
|
if kind in request_response_dispatchers:
|
||||||
|
dispatcher = request_response_dispatchers[kind](
|
||||||
|
backend = self.backend,
|
||||||
|
request_queue = qconfig["request"],
|
||||||
|
response_queue = qconfig["response"],
|
||||||
|
timeout = 120,
|
||||||
|
consumer = f"{self.prefix}-{flow}-{kind}-request",
|
||||||
|
subscriber = f"{self.prefix}-{flow}-{kind}-request",
|
||||||
|
)
|
||||||
|
elif kind in sender_dispatchers:
|
||||||
|
dispatcher = sender_dispatchers[kind](
|
||||||
|
backend = self.backend,
|
||||||
|
queue = qconfig,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Invalid kind")
|
||||||
|
|
||||||
if kind in request_response_dispatchers:
|
await dispatcher.start()
|
||||||
dispatcher = request_response_dispatchers[kind](
|
self.dispatchers[key] = dispatcher
|
||||||
backend = self.backend,
|
|
||||||
request_queue = qconfig["request"],
|
|
||||||
response_queue = qconfig["response"],
|
|
||||||
timeout = 120,
|
|
||||||
consumer = f"{self.prefix}-{flow}-{kind}-request",
|
|
||||||
subscriber = f"{self.prefix}-{flow}-{kind}-request",
|
|
||||||
)
|
|
||||||
elif kind in sender_dispatchers:
|
|
||||||
dispatcher = sender_dispatchers[kind](
|
|
||||||
backend = self.backend,
|
|
||||||
queue = qconfig,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Invalid kind")
|
|
||||||
|
|
||||||
await dispatcher.start()
|
return await self.dispatchers[key].process(data, responder)
|
||||||
|
|
||||||
self.dispatchers[key] = dispatcher
|
|
||||||
|
|
||||||
return await dispatcher.process(data, responder)
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from .... base import LlmService, LlmResult, LlmChunk
|
||||||
|
|
||||||
default_ident = "text-completion"
|
default_ident = "text-completion"
|
||||||
|
|
||||||
default_model = 'gemma2:9b'
|
default_model = 'granite4:350m'
|
||||||
default_temperature = 0.0
|
default_temperature = 0.0
|
||||||
default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
|
default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
|
||||||
|
|
||||||
|
|
@ -39,11 +39,36 @@ class Processor(LlmService):
|
||||||
self.default_model = model
|
self.default_model = model
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.llm = Client(host=ollama)
|
self.llm = Client(host=ollama)
|
||||||
|
self._checked_models = set()
|
||||||
|
|
||||||
|
def _ensure_model(self, model_name):
|
||||||
|
"""Check if model exists locally, pull it if not."""
|
||||||
|
if model_name in self._checked_models:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.llm.show(model_name)
|
||||||
|
self._checked_models.add(model_name)
|
||||||
|
except Exception as e:
|
||||||
|
status_code = getattr(e, 'status_code', None)
|
||||||
|
if status_code == 404 or "not found" in str(e).lower():
|
||||||
|
logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
|
||||||
|
try:
|
||||||
|
self.llm.pull(model_name)
|
||||||
|
self._checked_models.add(model_name)
|
||||||
|
logger.info(f"Successfully pulled Ollama model '{model_name}'.")
|
||||||
|
except Exception as pull_e:
|
||||||
|
logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Failed to check Ollama model '{model_name}': {e}")
|
||||||
|
|
||||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||||
|
|
||||||
# Use provided model or fall back to default
|
# Use provided model or fall back to default
|
||||||
model_name = model or self.default_model
|
model_name = model or self.default_model
|
||||||
|
|
||||||
|
# Ensure the model exists/is pulled
|
||||||
|
self._ensure_model(model_name)
|
||||||
# Use provided temperature or fall back to default
|
# Use provided temperature or fall back to default
|
||||||
effective_temperature = temperature if temperature is not None else self.temperature
|
effective_temperature = temperature if temperature is not None else self.temperature
|
||||||
|
|
||||||
|
|
@ -86,6 +111,10 @@ class Processor(LlmService):
|
||||||
async def generate_content_stream(self, system, prompt, model=None, temperature=None):
|
async def generate_content_stream(self, system, prompt, model=None, temperature=None):
|
||||||
"""Stream content generation from Ollama"""
|
"""Stream content generation from Ollama"""
|
||||||
model_name = model or self.default_model
|
model_name = model or self.default_model
|
||||||
|
|
||||||
|
# Ensure the model exists/is pulled
|
||||||
|
self._ensure_model(model_name)
|
||||||
|
|
||||||
effective_temperature = temperature if temperature is not None else self.temperature
|
effective_temperature = temperature if temperature is not None else self.temperature
|
||||||
|
|
||||||
logger.debug(f"Using model (streaming): {model_name}")
|
logger.debug(f"Using model (streaming): {model_name}")
|
||||||
|
|
@ -142,7 +171,7 @@ class Processor(LlmService):
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-m', '--model',
|
'-m', '--model',
|
||||||
default="gemma2",
|
default="granite4:350m",
|
||||||
help=f'LLM model (default: {default_model})'
|
help=f'LLM model (default: {default_model})'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue