Fix Ollama async issue (#854)

* Fix Ollama sync issues - replaced synchronous Client calls with AsyncClient

* Fix tests
cybermaggedon 2026-04-28 15:43:04 +01:00 committed by Cyber MacGeddon
parent fad005e030
commit a8fdf547db
4 changed files with 54 additions and 54 deletions
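Not part of the commit itself, but for orientation: both processors follow the same conversion pattern, shown here as a condensed sketch rather than the repository's exact code (the host default comes from the tests below; the model name is an illustrative placeholder).

from ollama import AsyncClient  # previously: from ollama import Client

class Processor:

    def __init__(self, ollama="http://localhost:11434", model="example-model"):
        # AsyncClient replaces the blocking Client; construction is otherwise unchanged
        self.client = AsyncClient(host=ollama)
        self.default_model = model

    async def on_embeddings(self, texts, model=None):
        # Client calls become coroutines, so they are awaited and no longer block the event loop
        response = await self.client.embed(model=model or self.default_model, input=texts)
        return response.embeddings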


@@ -14,13 +14,13 @@ from trustgraph.embeddings.ollama.processor import Processor
class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
"""Test Ollama dynamic model selection"""
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that Ollama client is initialized with correct host"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
@@ -36,13 +36,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
mock_client_class.assert_called_once_with(host="http://localhost:11434")
assert processor.default_model == "test-model"
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that on_embeddings uses default model when no model specified"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
@@ -62,13 +62,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
)
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that on_embeddings uses specified model when provided"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
@@ -88,13 +88,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
)
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test switching between multiple models"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
@@ -118,13 +118,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
assert calls[2][1]['model'] == "model-a"
assert calls[3][1]['model'] == "test-model" # Default
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that None model parameter falls back to default"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
@@ -143,13 +143,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
input=["test text"]
)
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test initialization without model parameter uses module default"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_client_class.return_value = mock_ollama_client
mock_async_init.return_value = None
mock_embeddings_init.return_value = None

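The test changes above all follow one recipe: patch AsyncClient rather than Client, and back it with an AsyncMock so the awaited embed() call resolves to the stubbed response. A minimal sketch of that pattern, with the base-class __init__ patches and the processor construction elided (the real tests above include them):

from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, Mock, patch

class ExampleAsyncClientPatching(IsolatedAsyncioTestCase):

    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
    async def test_embed_resolves_to_stub(self, mock_client_class):
        # AsyncMock makes embed() awaitable; awaiting it returns return_value
        mock_client = AsyncMock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3]]
        mock_client.embed.return_value = mock_response
        mock_client_class.return_value = mock_client

        # ... construct the processor as in the tests above, call
        # await processor.on_embeddings(...), then assert, e.g.:
        # mock_client.embed.assert_awaited_once_with(model="test-model", input=["test text"])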

@@ -15,13 +15,13 @@ from trustgraph.base import LlmResult
class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
"""Test Ollama processor functionality"""
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test basic processor initialization"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_client_class.return_value = mock_client
# Mock the parent class initialization
@@ -44,13 +44,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert hasattr(processor, 'llm')
mock_client_class.assert_called_once_with(host='http://localhost:11434')
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test successful content generation"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Generated response from Ollama',
'prompt_eval_count': 15,
@@ -83,13 +83,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert result.model == 'llama2'
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test handling of generic exceptions"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_client.generate.side_effect = Exception("Connection error")
mock_client_class.return_value = mock_client
@@ -110,13 +110,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
with pytest.raises(Exception, match="Connection error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test processor initialization with custom parameters"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
@@ -137,13 +137,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert processor.default_model == 'mistral'
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test processor initialization with default values"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
@@ -164,13 +164,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
mock_client_class.assert_called_once()
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test content generation with empty prompts"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Default response',
'prompt_eval_count': 2,
@@ -205,13 +205,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
# The prompt should be "" + "\n\n" + "" = "\n\n"
mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_token_counting(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test token counting from Ollama response"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Test response',
'prompt_eval_count': 50,
@@ -243,13 +243,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert result.out_token == 25
assert result.model == 'llama2'
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_ollama_client_initialization(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test that Ollama client is initialized correctly"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
@@ -273,13 +273,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
# Verify processor has the client
assert processor.llm == mock_client
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test prompt construction with system and user prompts"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Response with system instructions',
'prompt_eval_count': 25,
@@ -312,13 +312,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
# Verify the combined prompt
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Response with custom temperature',
'prompt_eval_count': 20,
@@ -360,13 +360,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
options={'temperature': 0.8} # Should use runtime override
)
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test model parameter override functionality"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Response with custom model',
'prompt_eval_count': 18,
@@ -408,13 +408,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
options={'temperature': 0.1} # Should use processor default
)
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Response with both overrides',
'prompt_eval_count': 22,

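One detail these tests rely on: with AsyncMock, a side_effect exception is raised at the point the coroutine is awaited, so pytest.raises wraps the awaited call exactly as it wrapped the synchronous one. A small self-contained sketch (illustrative names, not the repository's test):

import pytest
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock

class ExampleAsyncErrorPropagation(IsolatedAsyncioTestCase):

    async def test_side_effect_raises_on_await(self):
        mock_client = AsyncMock()
        mock_client.generate.side_effect = Exception("Connection error")

        # The error surfaces when the coroutine is awaited, not when generate() is looked up
        with pytest.raises(Exception, match="Connection error"):
            await mock_client.generate("llama2", "prompt", options={"temperature": 0.0})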

@@ -5,7 +5,7 @@ Input is text, output is embeddings vector.
"""
from ... base import EmbeddingsService
from ollama import Client
from ollama import AsyncClient
import os
import logging
@@ -30,24 +30,24 @@ class Processor(EmbeddingsService):
}
)
self.client = Client(host=ollama)
self.client = AsyncClient(host=ollama)
self.default_model = model
self._checked_models = set()
def _ensure_model(self, model_name):
async def _ensure_model(self, model_name):
"""Check if model exists locally, pull it if not."""
if model_name in self._checked_models:
return
try:
self.client.show(model_name)
await self.client.show(model_name)
self._checked_models.add(model_name)
except Exception as e:
status_code = getattr(e, 'status_code', None)
if status_code == 404 or "not found" in str(e).lower():
logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
try:
self.client.pull(model_name)
await self.client.pull(model_name)
self._checked_models.add(model_name)
logger.info(f"Successfully pulled Ollama model '{model_name}'.")
except Exception as pull_e:
@@ -63,10 +63,10 @@ class Processor(EmbeddingsService):
use_model = model or self.default_model
# Ensure the model exists/is pulled
self._ensure_model(use_model)
await self._ensure_model(use_model)
# Ollama handles batch input efficiently
embeds = self.client.embed(
embeds = await self.client.embed(
model = use_model,
input = texts
)

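For reference, the awaited call the embeddings processor now makes maps onto the plain ollama API like this; a minimal sketch assuming a reachable Ollama server at the default host and an illustrative embedding model name:

import asyncio
from ollama import AsyncClient

async def main():
    client = AsyncClient(host="http://localhost:11434")
    # embed() on AsyncClient is a coroutine; the response exposes .embeddings
    response = await client.embed(model="all-minilm", input=["text one", "text two"])
    print(len(response.embeddings), len(response.embeddings[0]))

asyncio.run(main())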

@@ -4,7 +4,7 @@ Simple LLM service, performs text prompt completion using an Ollama service.
Input is prompt, output is response.
"""
from ollama import Client
from ollama import AsyncClient
import os
import logging
@@ -38,23 +38,23 @@ class Processor(LlmService):
self.default_model = model
self.temperature = temperature
self.llm = Client(host=ollama)
self.llm = AsyncClient(host=ollama)
self._checked_models = set()
def _ensure_model(self, model_name):
async def _ensure_model(self, model_name):
"""Check if model exists locally, pull it if not."""
if model_name in self._checked_models:
return
try:
self.llm.show(model_name)
await self.llm.show(model_name)
self._checked_models.add(model_name)
except Exception as e:
status_code = getattr(e, 'status_code', None)
if status_code == 404 or "not found" in str(e).lower():
logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
try:
self.llm.pull(model_name)
await self.llm.pull(model_name)
self._checked_models.add(model_name)
logger.info(f"Successfully pulled Ollama model '{model_name}'.")
except Exception as pull_e:
@@ -66,9 +66,9 @@ class Processor(LlmService):
# Use provided model or fall back to default
model_name = model or self.default_model
# Ensure the model exists/is pulled
self._ensure_model(model_name)
await self._ensure_model(model_name)
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
@@ -79,7 +79,7 @@
try:
response = self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
response = await self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
response_text = response['response']
logger.debug("Sending response...")
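The non-streaming path now awaits generate() directly and still indexes the response the same way. A standalone sketch of that call shape (host and model are illustrative):

import asyncio
from ollama import AsyncClient

async def main():
    client = AsyncClient(host="http://localhost:11434")
    # Awaited generate(); the result still carries 'response' plus the token counts
    # (prompt_eval_count / eval_count) that the processor reads into LlmResult
    response = await client.generate(
        "llama2", "System prompt\n\nUser prompt", options={"temperature": 0.0}
    )
    print(response["response"])

asyncio.run(main())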
@@ -113,7 +113,7 @@ class Processor(LlmService):
model_name = model or self.default_model
# Ensure the model exists/is pulled
self._ensure_model(model_name)
await self._ensure_model(model_name)
effective_temperature = temperature if temperature is not None else self.temperature
@@ -123,7 +123,7 @@ class Processor(LlmService):
prompt = system + "\n\n" + prompt
try:
stream = self.llm.generate(
stream = await self.llm.generate(
model_name,
prompt,
options={'temperature': effective_temperature},
@@ -133,7 +133,7 @@ class Processor(LlmService):
total_input_tokens = 0
total_output_tokens = 0
for chunk in stream:
async for chunk in stream:
if 'response' in chunk and chunk['response']:
yield LlmChunk(
text=chunk['response'],
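The streaming path is where the loop shape changes: the awaited generate(..., stream=True) call yields an async iterator, so the plain for becomes async for. A hedged sketch of that pattern (stream=True is not visible in the truncated hunk above but is implied by the iteration; host and model are illustrative):

import asyncio
from ollama import AsyncClient

async def main():
    client = AsyncClient(host="http://localhost:11434")
    # With stream=True the awaited call returns an async iterator of chunks
    stream = await client.generate(
        "llama2",
        "Why is the sky blue?",
        options={"temperature": 0.0},
        stream=True,
    )
    async for chunk in stream:
        if "response" in chunk and chunk["response"]:
            print(chunk["response"], end="", flush=True)

asyncio.run(main())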