Fix Ollama async issue (#854)

* Fix Ollama sync issues - replaced with async

* Fix tests
cybermaggedon 2026-04-28 15:43:04 +01:00 committed by Cyber MacGeddon
parent fad005e030
commit a8fdf547db
4 changed files with 54 additions and 54 deletions
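The change follows a single pattern in all four files: the blocking ollama.Client is replaced by ollama.AsyncClient, every client call is awaited, and the test doubles become AsyncMock so the awaited calls resolve. A minimal sketch of the core swap, assuming a reachable Ollama server; the host and model values below mirror those used in the tests and are placeholders, not part of the diff:

from ollama import AsyncClient  # previously: from ollama import Client

async def complete(prompt: str) -> str:
    # AsyncClient calls return coroutines and must be awaited
    client = AsyncClient(host="http://localhost:11434")
    response = await client.generate("llama2", prompt, options={"temperature": 0.0})
    return response["response"]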

View file

@@ -14,13 +14,13 @@ from trustgraph.embeddings.ollama.processor import Processor
class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
    """Test Ollama dynamic model selection"""

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that Ollama client is initialized with correct host"""
        # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_ollama_client.embed.return_value = mock_response
@@ -36,13 +36,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
        mock_client_class.assert_called_once_with(host="http://localhost:11434")
        assert processor.default_model == "test-model"

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that on_embeddings uses default model when no model specified"""
        # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_ollama_client.embed.return_value = mock_response
@@ -62,13 +62,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
        )
        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that on_embeddings uses specified model when provided"""
        # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_ollama_client.embed.return_value = mock_response
@@ -88,13 +88,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
        )
        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test switching between multiple models"""
        # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_ollama_client.embed.return_value = mock_response
@@ -118,13 +118,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
        assert calls[2][1]['model'] == "model-a"
        assert calls[3][1]['model'] == "test-model"  # Default

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that None model parameter falls back to default"""
        # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_ollama_client.embed.return_value = mock_response
@@ -143,13 +143,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
            input=["test text"]
        )

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test initialization without model parameter uses module default"""
        # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
        mock_client_class.return_value = mock_ollama_client
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
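The mocks move from Mock to AsyncMock because the processor now awaits client.embed(...); a plain Mock returns a value that cannot be awaited, while an AsyncMock's attributes are awaitable and resolve to their configured return_value. A self-contained sketch of that behaviour, standard library only and not taken from the test file:

import asyncio
from unittest.mock import AsyncMock, Mock

async def main():
    client = AsyncMock()
    response = Mock()
    response.embeddings = [[0.1, 0.2, 0.3]]
    client.embed.return_value = response

    # Awaiting the AsyncMock call yields the configured return value
    result = await client.embed(model="test-model", input=["some text"])
    assert result.embeddings == [[0.1, 0.2, 0.3]]
    client.embed.assert_called_once_with(model="test-model", input=["some text"])

asyncio.run(main())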

View file

@@ -15,13 +15,13 @@ from trustgraph.base import LlmResult
class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
    """Test Ollama processor functionality"""

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test basic processor initialization"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_client_class.return_value = mock_client

        # Mock the parent class initialization
@@ -44,13 +44,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        assert hasattr(processor, 'llm')
        mock_client_class.assert_called_once_with(host='http://localhost:11434')

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test successful content generation"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_response = {
            'response': 'Generated response from Ollama',
            'prompt_eval_count': 15,
@@ -83,13 +83,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        assert result.model == 'llama2'
        mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0})

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test handling of generic exceptions"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_client.generate.side_effect = Exception("Connection error")
        mock_client_class.return_value = mock_client
@@ -110,13 +110,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        with pytest.raises(Exception, match="Connection error"):
            await processor.generate_content("System prompt", "User prompt")

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test processor initialization with custom parameters"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_client_class.return_value = mock_client
        mock_async_init.return_value = None
@@ -137,13 +137,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        assert processor.default_model == 'mistral'
        mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test processor initialization with default values"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_client_class.return_value = mock_client
        mock_async_init.return_value = None
@@ -164,13 +164,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        # Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
        mock_client_class.assert_called_once()

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test content generation with empty prompts"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_response = {
            'response': 'Default response',
            'prompt_eval_count': 2,
@@ -205,13 +205,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        # The prompt should be "" + "\n\n" + "" = "\n\n"
        mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0})

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_token_counting(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test token counting from Ollama response"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_response = {
            'response': 'Test response',
            'prompt_eval_count': 50,
@@ -243,13 +243,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        assert result.out_token == 25
        assert result.model == 'llama2'

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_ollama_client_initialization(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test that Ollama client is initialized correctly"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_client_class.return_value = mock_client
        mock_async_init.return_value = None
@@ -273,13 +273,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        # Verify processor has the client
        assert processor.llm == mock_client

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test prompt construction with system and user prompts"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_response = {
            'response': 'Response with system instructions',
            'prompt_eval_count': 25,
@@ -312,13 +312,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        # Verify the combined prompt
        mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test temperature parameter override functionality"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_response = {
            'response': 'Response with custom temperature',
            'prompt_eval_count': 20,
@@ -360,13 +360,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
            options={'temperature': 0.8}  # Should use runtime override
        )

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test model parameter override functionality"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_response = {
            'response': 'Response with custom model',
            'prompt_eval_count': 18,
@@ -408,13 +408,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
            options={'temperature': 0.1}  # Should use processor default
        )

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
        """Test overriding both model and temperature parameters simultaneously"""
        # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
        mock_response = {
            'response': 'Response with both overrides',
            'prompt_eval_count': 22,
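The exception tests keep working after the switch because an AsyncMock raises its configured side_effect at the point the awaited call resolves, so pytest.raises still observes the error. A small sketch of that behaviour, standard library only and not taken from the test file:

import asyncio
from unittest.mock import AsyncMock

async def main():
    client = AsyncMock()
    client.generate.side_effect = Exception("Connection error")

    try:
        # The side_effect is raised when the awaited call resolves
        await client.generate("llama2", "prompt", options={"temperature": 0.0})
    except Exception as e:
        assert str(e) == "Connection error"

asyncio.run(main())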

View file

@@ -5,7 +5,7 @@ Input is text, output is embeddings vector.
"""

from ... base import EmbeddingsService
-from ollama import Client
+from ollama import AsyncClient
import os
import logging
@@ -30,24 +30,24 @@ class Processor(EmbeddingsService):
            }
        )

-        self.client = Client(host=ollama)
+        self.client = AsyncClient(host=ollama)
        self.default_model = model
        self._checked_models = set()

-    def _ensure_model(self, model_name):
+    async def _ensure_model(self, model_name):
        """Check if model exists locally, pull it if not."""
        if model_name in self._checked_models:
            return

        try:
-            self.client.show(model_name)
+            await self.client.show(model_name)
            self._checked_models.add(model_name)
        except Exception as e:
            status_code = getattr(e, 'status_code', None)
            if status_code == 404 or "not found" in str(e).lower():
                logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
                try:
-                    self.client.pull(model_name)
+                    await self.client.pull(model_name)
                    self._checked_models.add(model_name)
                    logger.info(f"Successfully pulled Ollama model '{model_name}'.")
                except Exception as pull_e:
@@ -63,10 +63,10 @@ class Processor(EmbeddingsService):
        use_model = model or self.default_model

        # Ensure the model exists/is pulled
-        self._ensure_model(use_model)
+        await self._ensure_model(use_model)

        # Ollama handles batch input efficiently
-        embeds = self.client.embed(
+        embeds = await self.client.embed(
            model = use_model,
            input = texts
        )
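With AsyncClient, show, pull and embed are all coroutines, so the ensure-then-embed flow above becomes non-blocking end to end. A hedged sketch of the same flow driven directly, outside the processor class; the host and model names are placeholders and it assumes a reachable Ollama server:

import asyncio
from ollama import AsyncClient

async def ensure_and_embed(host, model, texts):
    client = AsyncClient(host=host)
    try:
        # show() raises if the model is not available locally
        await client.show(model)
    except Exception:
        # pull() downloads the model; this can take a while
        await client.pull(model)
    response = await client.embed(model=model, input=texts)
    return response.embeddings

# Example: asyncio.run(ensure_and_embed("http://localhost:11434", "an-embedding-model", ["hello"]))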

View file

@@ -4,7 +4,7 @@ Simple LLM service, performs text prompt completion using an Ollama service.
Input is prompt, output is response.
"""

-from ollama import Client
+from ollama import AsyncClient
import os
import logging
@@ -38,23 +38,23 @@ class Processor(LlmService):
        self.default_model = model
        self.temperature = temperature

-        self.llm = Client(host=ollama)
+        self.llm = AsyncClient(host=ollama)
        self._checked_models = set()

-    def _ensure_model(self, model_name):
+    async def _ensure_model(self, model_name):
        """Check if model exists locally, pull it if not."""
        if model_name in self._checked_models:
            return

        try:
-            self.llm.show(model_name)
+            await self.llm.show(model_name)
            self._checked_models.add(model_name)
        except Exception as e:
            status_code = getattr(e, 'status_code', None)
            if status_code == 404 or "not found" in str(e).lower():
                logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
                try:
-                    self.llm.pull(model_name)
+                    await self.llm.pull(model_name)
                    self._checked_models.add(model_name)
                    logger.info(f"Successfully pulled Ollama model '{model_name}'.")
                except Exception as pull_e:
@@ -68,7 +68,7 @@ class Processor(LlmService):
        model_name = model or self.default_model

        # Ensure the model exists/is pulled
-        self._ensure_model(model_name)
+        await self._ensure_model(model_name)

        # Use provided temperature or fall back to default
        effective_temperature = temperature if temperature is not None else self.temperature
@@ -79,7 +79,7 @@ class Processor(LlmService):
        try:
-            response = self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
+            response = await self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
            response_text = response['response']

            logger.debug("Sending response...")
@@ -113,7 +113,7 @@ class Processor(LlmService):
        model_name = model or self.default_model

        # Ensure the model exists/is pulled
-        self._ensure_model(model_name)
+        await self._ensure_model(model_name)

        effective_temperature = temperature if temperature is not None else self.temperature
@@ -123,7 +123,7 @@ class Processor(LlmService):
        prompt = system + "\n\n" + prompt

        try:
-            stream = self.llm.generate(
+            stream = await self.llm.generate(
                model_name,
                prompt,
                options={'temperature': effective_temperature},
@@ -133,7 +133,7 @@ class Processor(LlmService):
            total_input_tokens = 0
            total_output_tokens = 0

-            for chunk in stream:
+            async for chunk in stream:
                if 'response' in chunk and chunk['response']:
                    yield LlmChunk(
                        text=chunk['response'],
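Because AsyncClient.generate(..., stream=True) yields an async iterator once awaited, the streaming loop becomes async for, as in the hunk above. A minimal sketch of consuming such a stream; the host and model are placeholders and it assumes a reachable Ollama server:

import asyncio
from ollama import AsyncClient

async def stream_completion(host, model, prompt):
    client = AsyncClient(host=host)
    # Awaiting the call with stream=True returns an async iterator of chunks
    stream = await client.generate(model, prompt, options={"temperature": 0.0}, stream=True)
    parts = []
    async for chunk in stream:
        if "response" in chunk and chunk["response"]:
            parts.append(chunk["response"])
    return "".join(parts)

# Example: print(asyncio.run(stream_completion("http://localhost:11434", "llama2", "Hello")))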