Feat: Auto-pull missing Ollama models (#757)

* fix deadlink in readme

Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>

* feat: Auto-pull Ollama models

Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>

* fix: Restore namespace __init__.py files for package resolution

Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>

* fix CI

Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>
This commit is contained in:
Alex Jenkins 2026-04-06 10:10:14 +00:00 committed by GitHub
parent be443a1679
commit 7daa06e9e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 60 additions and 3 deletions

View file

@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Assert # Assert
assert processor.default_model == 'gemma2:9b' # default_model assert processor.default_model == 'granite4:350m' # default_model
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env) # Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
mock_client_class.assert_called_once() mock_client_class.assert_called_once()

View file

@ -7,6 +7,9 @@ from ... base import EmbeddingsService
from ollama import Client from ollama import Client
import os import os
import logging
logger = logging.getLogger(__name__)
default_ident = "embeddings" default_ident = "embeddings"
@ -29,6 +32,28 @@ class Processor(EmbeddingsService):
self.client = Client(host=ollama) self.client = Client(host=ollama)
self.default_model = model self.default_model = model
self._checked_models = set()
def _ensure_model(self, model_name):
    """Ensure *model_name* exists on the Ollama server, pulling it if absent.

    Results are memoised in ``self._checked_models`` so the (potentially
    remote) existence check runs at most once per model per process.
    All failures are logged but never raised: the caller's subsequent
    embed call will surface the real error if the model is still unusable.

    :param model_name: name of the Ollama model, e.g. ``'granite4:350m'``
    """
    if model_name in self._checked_models:
        return
    try:
        # show() raises if the model is not present on the server.
        self.client.show(model_name)
        self._checked_models.add(model_name)
    except Exception as e:
        # The ollama client attaches .status_code to HTTP errors; fall
        # back to message sniffing in case the attribute is missing.
        status_code = getattr(e, 'status_code', None)
        if status_code == 404 or "not found" in str(e).lower():
            # Lazy %-args: the message is only formatted if INFO is enabled.
            logger.info(
                "Ollama model '%s' not found locally. Pulling, this may take a while...",
                model_name,
            )
            try:
                self.client.pull(model_name)
                self._checked_models.add(model_name)
                logger.info("Successfully pulled Ollama model '%s'.", model_name)
            except Exception as pull_e:
                # Best-effort: leave the model un-cached so a later call retries.
                logger.error("Failed to pull Ollama model '%s': %s", model_name, pull_e)
        else:
            logger.warning("Failed to check Ollama model '%s': %s", model_name, e)
async def on_embeddings(self, texts, model=None): async def on_embeddings(self, texts, model=None):
@ -37,6 +62,9 @@ class Processor(EmbeddingsService):
use_model = model or self.default_model use_model = model or self.default_model
# Ensure the model exists/is pulled
self._ensure_model(use_model)
# Ollama handles batch input efficiently # Ollama handles batch input efficiently
embeds = self.client.embed( embeds = self.client.embed(
model = use_model, model = use_model,

View file

@ -16,7 +16,7 @@ from .... base import LlmService, LlmResult, LlmChunk
default_ident = "text-completion" default_ident = "text-completion"
default_model = 'gemma2:9b' default_model = 'granite4:350m'
default_temperature = 0.0 default_temperature = 0.0
default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434') default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
@ -39,11 +39,36 @@ class Processor(LlmService):
self.default_model = model self.default_model = model
self.temperature = temperature self.temperature = temperature
self.llm = Client(host=ollama) self.llm = Client(host=ollama)
self._checked_models = set()
def _ensure_model(self, model_name):
    """Ensure *model_name* exists on the Ollama server, pulling it if absent.

    Results are memoised in ``self._checked_models`` so the (potentially
    remote) existence check runs at most once per model per process.
    All failures are logged but never raised: the caller's subsequent
    generate call will surface the real error if the model is still unusable.

    :param model_name: name of the Ollama model, e.g. ``'granite4:350m'``
    """
    if model_name in self._checked_models:
        return
    try:
        # show() raises if the model is not present on the server.
        self.llm.show(model_name)
        self._checked_models.add(model_name)
    except Exception as e:
        # The ollama client attaches .status_code to HTTP errors; fall
        # back to message sniffing in case the attribute is missing.
        status_code = getattr(e, 'status_code', None)
        if status_code == 404 or "not found" in str(e).lower():
            # Lazy %-args: the message is only formatted if INFO is enabled.
            logger.info(
                "Ollama model '%s' not found locally. Pulling, this may take a while...",
                model_name,
            )
            try:
                self.llm.pull(model_name)
                self._checked_models.add(model_name)
                logger.info("Successfully pulled Ollama model '%s'.", model_name)
            except Exception as pull_e:
                # Best-effort: leave the model un-cached so a later call retries.
                logger.error("Failed to pull Ollama model '%s': %s", model_name, pull_e)
        else:
            logger.warning("Failed to check Ollama model '%s': %s", model_name, e)
async def generate_content(self, system, prompt, model=None, temperature=None): async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default # Use provided model or fall back to default
model_name = model or self.default_model model_name = model or self.default_model
# Ensure the model exists/is pulled
self._ensure_model(model_name)
# Use provided temperature or fall back to default # Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature effective_temperature = temperature if temperature is not None else self.temperature
@ -86,6 +111,10 @@ class Processor(LlmService):
async def generate_content_stream(self, system, prompt, model=None, temperature=None): async def generate_content_stream(self, system, prompt, model=None, temperature=None):
"""Stream content generation from Ollama""" """Stream content generation from Ollama"""
model_name = model or self.default_model model_name = model or self.default_model
# Ensure the model exists/is pulled
self._ensure_model(model_name)
effective_temperature = temperature if temperature is not None else self.temperature effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model (streaming): {model_name}") logger.debug(f"Using model (streaming): {model_name}")
@ -142,7 +171,7 @@ class Processor(LlmService):
parser.add_argument( parser.add_argument(
'-m', '--model', '-m', '--model',
default="gemma2", default="granite4:350m",
help=f'LLM model (default: {default_model})' help=f'LLM model (default: {default_model})'
) )