diff --git a/tests/unit/test_text_completion/test_ollama_processor.py b/tests/unit/test_text_completion/test_ollama_processor.py
index 0bf5e0ab..69baf85f 100644
--- a/tests/unit/test_text_completion/test_ollama_processor.py
+++ b/tests/unit/test_text_completion/test_ollama_processor.py
@@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
             processor = Processor(**config)
 
             # Assert
-            assert processor.default_model == 'gemma2:9b'  # default_model
+            assert processor.default_model == 'granite4:350m'  # default_model
             # Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
             mock_client_class.assert_called_once()
 
diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py
index a65b4ff7..c63db33c 100755
--- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py
+++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py
@@ -7,6 +7,9 @@
 from ... base import EmbeddingsService
 from ollama import Client
 import os
+import logging
+
+logger = logging.getLogger(__name__)
 
 default_ident = "embeddings"
 
@@ -29,6 +32,28 @@ class Processor(EmbeddingsService):
 
         self.client = Client(host=ollama)
         self.default_model = model
+        self._checked_models = set()
+
+    def _ensure_model(self, model_name):
+        """Check if model exists locally, pull it if not."""
+        if model_name in self._checked_models:
+            return
+
+        try:
+            self.client.show(model_name)
+            self._checked_models.add(model_name)
+        except Exception as e:
+            status_code = getattr(e, 'status_code', None)
+            if status_code == 404 or "not found" in str(e).lower():
+                logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
+                try:
+                    self.client.pull(model_name)
+                    self._checked_models.add(model_name)
+                    logger.info(f"Successfully pulled Ollama model '{model_name}'.")
+                except Exception as pull_e:
+                    logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}")
+            else:
+                logger.warning(f"Failed to check Ollama model '{model_name}': {e}")
 
     async def on_embeddings(self, texts, model=None):
 
@@ -37,6 +62,9 @@ class Processor(EmbeddingsService):
 
         use_model = model or self.default_model
 
+        # Ensure the model exists/is pulled
+        self._ensure_model(use_model)
+
         # Ollama handles batch input efficiently
         embeds = self.client.embed(
             model = use_model,
diff --git a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
index 3616e428..f6c5dcb8 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
@@ -16,7 +16,7 @@ from .... base import LlmService, LlmResult, LlmChunk
 
 default_ident = "text-completion"
 
-default_model = 'gemma2:9b'
+default_model = 'granite4:350m'
 default_temperature = 0.0
 default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
 
@@ -39,11 +39,36 @@ class Processor(LlmService):
         self.default_model = model
         self.temperature = temperature
         self.llm = Client(host=ollama)
+        self._checked_models = set()
+
+    def _ensure_model(self, model_name):
+        """Check if model exists locally, pull it if not."""
+        if model_name in self._checked_models:
+            return
+
+        try:
+            self.llm.show(model_name)
+            self._checked_models.add(model_name)
+        except Exception as e:
+            status_code = getattr(e, 'status_code', None)
+            if status_code == 404 or "not found" in str(e).lower():
+                logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
+                try:
+                    self.llm.pull(model_name)
+                    self._checked_models.add(model_name)
+                    logger.info(f"Successfully pulled Ollama model '{model_name}'.")
+                except Exception as pull_e:
+                    logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}")
+            else:
+                logger.warning(f"Failed to check Ollama model '{model_name}': {e}")
 
     async def generate_content(self, system, prompt, model=None, temperature=None):
 
         # Use provided model or fall back to default
         model_name = model or self.default_model
+
+        # Ensure the model exists/is pulled
+        self._ensure_model(model_name)
 
         # Use provided temperature or fall back to default
         effective_temperature = temperature if temperature is not None else self.temperature
@@ -86,6 +111,10 @@ class Processor(LlmService):
     async def generate_content_stream(self, system, prompt, model=None, temperature=None):
         """Stream content generation from Ollama"""
         model_name = model or self.default_model
+
+        # Ensure the model exists/is pulled
+        self._ensure_model(model_name)
+
         effective_temperature = temperature if temperature is not None else self.temperature
         logger.debug(f"Using model (streaming): {model_name}")
 
@@ -142,7 +171,7 @@ class Processor(LlmService):
 
         parser.add_argument(
             '-m', '--model',
-            default="gemma2",
+            default="granite4:350m",
            help=f'LLM model (default: {default_model})'
         )
 