mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Feat: Auto-pull missing Ollama models (#757)
* fix dead link in README
  Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>
* feat: Auto-pull Ollama models
  Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>
* fix: Restore namespace __init__.py files for package resolution
  Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>
* fix CI
  Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>
This commit is contained in:
parent
be443a1679
commit
7daa06e9e4
3 changed files with 60 additions and 3 deletions
|
|
@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert processor.default_model == 'gemma2:9b' # default_model
|
assert processor.default_model == 'granite4:350m' # default_model
|
||||||
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
|
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
|
||||||
mock_client_class.assert_called_once()
|
mock_client_class.assert_called_once()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,9 @@ from ... base import EmbeddingsService
|
||||||
|
|
||||||
from ollama import Client
|
from ollama import Client
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "embeddings"
|
default_ident = "embeddings"
|
||||||
|
|
||||||
|
|
@ -29,6 +32,28 @@ class Processor(EmbeddingsService):
|
||||||
|
|
||||||
self.client = Client(host=ollama)
|
self.client = Client(host=ollama)
|
||||||
self.default_model = model
|
self.default_model = model
|
||||||
|
self._checked_models = set()
|
||||||
|
|
||||||
|
def _ensure_model(self, model_name):
    """Check that an Ollama model exists locally and pull it if it does not.

    Model names that have already been verified (or pulled) are remembered
    in ``self._checked_models`` so the check runs at most once per
    successfully resolved model. Errors are logged rather than raised, so
    the caller proceeds to its own Ollama request either way.

    Args:
        model_name: Name of the Ollama model to verify.
    """
    # Fast path: this model was already confirmed earlier.
    if model_name in self._checked_models:
        return

    try:
        self.client.show(model_name)
    except Exception as e:
        # A 404 status or a "not found" message is treated as "model is
        # missing"; any other failure is only reported.
        is_missing = (
            getattr(e, 'status_code', None) == 404
            or "not found" in str(e).lower()
        )
        if not is_missing:
            logger.warning(f"Failed to check Ollama model '{model_name}': {e}")
            return
    else:
        self._checked_models.add(model_name)
        return

    # Only reached when the lookup reported the model as missing.
    logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
    try:
        self.client.pull(model_name)
        self._checked_models.add(model_name)
        logger.info(f"Successfully pulled Ollama model '{model_name}'.")
    except Exception as pull_e:
        # The model is not recorded as checked, so a later call will
        # attempt the check/pull again.
        logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}")
|
||||||
|
|
||||||
async def on_embeddings(self, texts, model=None):
|
async def on_embeddings(self, texts, model=None):
|
||||||
|
|
||||||
|
|
@ -37,6 +62,9 @@ class Processor(EmbeddingsService):
|
||||||
|
|
||||||
use_model = model or self.default_model
|
use_model = model or self.default_model
|
||||||
|
|
||||||
|
# Ensure the model exists/is pulled
|
||||||
|
self._ensure_model(use_model)
|
||||||
|
|
||||||
# Ollama handles batch input efficiently
|
# Ollama handles batch input efficiently
|
||||||
embeds = self.client.embed(
|
embeds = self.client.embed(
|
||||||
model = use_model,
|
model = use_model,
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from .... base import LlmService, LlmResult, LlmChunk
|
||||||
|
|
||||||
default_ident = "text-completion"
|
default_ident = "text-completion"
|
||||||
|
|
||||||
default_model = 'gemma2:9b'
|
default_model = 'granite4:350m'
|
||||||
default_temperature = 0.0
|
default_temperature = 0.0
|
||||||
default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
|
default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
|
||||||
|
|
||||||
|
|
@ -39,11 +39,36 @@ class Processor(LlmService):
|
||||||
self.default_model = model
|
self.default_model = model
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.llm = Client(host=ollama)
|
self.llm = Client(host=ollama)
|
||||||
|
self._checked_models = set()
|
||||||
|
|
||||||
|
def _ensure_model(self, model_name):
    """Check that an Ollama model exists locally and pull it if it does not.

    Names that were verified or pulled successfully are cached in
    ``self._checked_models`` and skipped on later calls. Failures are
    logged and not re-raised, so the caller continues with its own
    request to Ollama regardless of the outcome here.

    Args:
        model_name: Name of the Ollama model to verify.
    """
    # Nothing to do for a model that was already confirmed.
    if model_name in self._checked_models:
        return

    try:
        self.llm.show(model_name)
    except Exception as e:
        # "Missing" means a 404 status or a "not found" message; anything
        # else is reported as a failed check and left alone.
        is_missing = (
            getattr(e, 'status_code', None) == 404
            or "not found" in str(e).lower()
        )
        if not is_missing:
            logger.warning(f"Failed to check Ollama model '{model_name}': {e}")
            return
    else:
        self._checked_models.add(model_name)
        return

    # Only reached when the lookup reported the model as missing.
    logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
    try:
        self.llm.pull(model_name)
        self._checked_models.add(model_name)
        logger.info(f"Successfully pulled Ollama model '{model_name}'.")
    except Exception as pull_e:
        # Left out of the cache on failure so the next call retries.
        logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}")
|
||||||
|
|
||||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||||
|
|
||||||
# Use provided model or fall back to default
|
# Use provided model or fall back to default
|
||||||
model_name = model or self.default_model
|
model_name = model or self.default_model
|
||||||
|
|
||||||
|
# Ensure the model exists/is pulled
|
||||||
|
self._ensure_model(model_name)
|
||||||
# Use provided temperature or fall back to default
|
# Use provided temperature or fall back to default
|
||||||
effective_temperature = temperature if temperature is not None else self.temperature
|
effective_temperature = temperature if temperature is not None else self.temperature
|
||||||
|
|
||||||
|
|
@ -86,6 +111,10 @@ class Processor(LlmService):
|
||||||
async def generate_content_stream(self, system, prompt, model=None, temperature=None):
|
async def generate_content_stream(self, system, prompt, model=None, temperature=None):
|
||||||
"""Stream content generation from Ollama"""
|
"""Stream content generation from Ollama"""
|
||||||
model_name = model or self.default_model
|
model_name = model or self.default_model
|
||||||
|
|
||||||
|
# Ensure the model exists/is pulled
|
||||||
|
self._ensure_model(model_name)
|
||||||
|
|
||||||
effective_temperature = temperature if temperature is not None else self.temperature
|
effective_temperature = temperature if temperature is not None else self.temperature
|
||||||
|
|
||||||
logger.debug(f"Using model (streaming): {model_name}")
|
logger.debug(f"Using model (streaming): {model_name}")
|
||||||
|
|
@ -142,7 +171,7 @@ class Processor(LlmService):
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-m', '--model',
|
'-m', '--model',
|
||||||
default="gemma2",
|
default="granite4:350m",
|
||||||
help=f'LLM model (default: {default_model})'
|
help=f'LLM model (default: {default_model})'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue