Mirror of https://github.com/trustgraph-ai/trustgraph.git, synced 2026-04-30 10:56:23 +02:00
Fix Ollama async issue (#854)

* Fix Ollama sync issues - replaced with async
* Fix tests

parent fad005e030
commit a8fdf547db

4 changed files with 54 additions and 54 deletions
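The change replaces every blocking call on the synchronous ollama Client with an awaited call on AsyncClient, in both the embeddings and text-completion processors, and updates the corresponding unit tests to use AsyncMock. As a rough, standalone sketch of the client pattern the diff moves to (not code from the repository; it assumes only the AsyncClient methods that appear in the hunks below, embed() and generate()):

import asyncio
from ollama import AsyncClient

async def main():
    client = AsyncClient(host="http://localhost:11434")

    # AsyncClient.embed() is a coroutine; the vectors are on the .embeddings attribute
    embeds = await client.embed(model="test-model", input=["some text"])

    # AsyncClient.generate() is likewise awaited; the text is under 'response'
    reply = await client.generate("llama2", "Hello", options={'temperature': 0.0})

    print(embeds.embeddings, reply['response'])

asyncio.run(main())

The host and model names above are placeholders taken from the tests, not fixed values; the services read them from configuration. The test updates come first in the diff.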
@@ -14,13 +14,13 @@ from trustgraph.embeddings.ollama.processor import Processor
 class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
     """Test Ollama dynamic model selection"""

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
     async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class):
         """Test that Ollama client is initialized with correct host"""
         # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
         mock_response = Mock()
         mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
         mock_ollama_client.embed.return_value = mock_response

@@ -36,13 +36,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         mock_client_class.assert_called_once_with(host="http://localhost:11434")
         assert processor.default_model == "test-model"

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
     async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
         """Test that on_embeddings uses default model when no model specified"""
         # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
         mock_response = Mock()
         mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
         mock_ollama_client.embed.return_value = mock_response

@@ -62,13 +62,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         )
         assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
     async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
         """Test that on_embeddings uses specified model when provided"""
         # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
         mock_response = Mock()
         mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
         mock_ollama_client.embed.return_value = mock_response

@@ -88,13 +88,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         )
         assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
     async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class):
         """Test switching between multiple models"""
         # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
         mock_response = Mock()
         mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
         mock_ollama_client.embed.return_value = mock_response

@@ -118,13 +118,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         assert calls[2][1]['model'] == "model-a"
         assert calls[3][1]['model'] == "test-model" # Default

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
     async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
         """Test that None model parameter falls back to default"""
         # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
         mock_response = Mock()
         mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
         mock_ollama_client.embed.return_value = mock_response

@@ -143,13 +143,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
             input=["test text"]
         )

-    @patch('trustgraph.embeddings.ollama.processor.Client')
+    @patch('trustgraph.embeddings.ollama.processor.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
     async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
         """Test initialization without model parameter uses module default"""
         # Arrange
-        mock_ollama_client = Mock()
+        mock_ollama_client = AsyncMock()
         mock_client_class.return_value = mock_ollama_client
         mock_async_init.return_value = None
         mock_embeddings_init.return_value = None
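The hunks above switch the embeddings-test doubles from Mock to AsyncMock, which is what lets the now-awaited embed() call resolve to the stubbed response (awaiting a plain Mock's return value raises a TypeError). A condensed, illustrative sketch of that mechanism, using only standard unittest.mock behaviour rather than repository code:

import asyncio
from unittest.mock import AsyncMock, Mock

client = AsyncMock()                  # calls on AsyncMock attributes return coroutines
response = Mock()
response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
client.embed.return_value = response  # delivered when the coroutine is awaited

async def demo():
    result = await client.embed(model="test-model", input=["test text"])
    assert result.embeddings == [[0.1, 0.2, 0.3, 0.4, 0.5]]
    client.embed.assert_awaited_once_with(model="test-model", input=["test text"])

asyncio.run(demo())

The next group of hunks applies the same treatment to the text-completion (LLM) tests.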
@@ -15,13 +15,13 @@ from trustgraph.base import LlmResult
 class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
     """Test Ollama processor functionality"""

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test basic processor initialization"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_client_class.return_value = mock_client

         # Mock the parent class initialization

@@ -44,13 +44,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
         assert hasattr(processor, 'llm')
         mock_client_class.assert_called_once_with(host='http://localhost:11434')

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test successful content generation"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_response = {
             'response': 'Generated response from Ollama',
             'prompt_eval_count': 15,

@@ -83,13 +83,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
         assert result.model == 'llama2'
         mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0})

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test handling of generic exceptions"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_client.generate.side_effect = Exception("Connection error")
         mock_client_class.return_value = mock_client

@@ -110,13 +110,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
         with pytest.raises(Exception, match="Connection error"):
             await processor.generate_content("System prompt", "User prompt")

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test processor initialization with custom parameters"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_client_class.return_value = mock_client

         mock_async_init.return_value = None

@@ -137,13 +137,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
         assert processor.default_model == 'mistral'
         mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test processor initialization with default values"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_client_class.return_value = mock_client

         mock_async_init.return_value = None

@@ -164,13 +164,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
         # Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
         mock_client_class.assert_called_once()

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test content generation with empty prompts"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_response = {
             'response': 'Default response',
             'prompt_eval_count': 2,

@@ -205,13 +205,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
         # The prompt should be "" + "\n\n" + "" = "\n\n"
         mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0})

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_generate_content_token_counting(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test token counting from Ollama response"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_response = {
             'response': 'Test response',
             'prompt_eval_count': 50,

@@ -243,13 +243,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
         assert result.out_token == 25
         assert result.model == 'llama2'

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_ollama_client_initialization(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test that Ollama client is initialized correctly"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_client_class.return_value = mock_client

         mock_async_init.return_value = None

@@ -273,13 +273,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
         # Verify processor has the client
         assert processor.llm == mock_client

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test prompt construction with system and user prompts"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_response = {
             'response': 'Response with system instructions',
             'prompt_eval_count': 25,

@@ -312,13 +312,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
         # Verify the combined prompt
         mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test temperature parameter override functionality"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_response = {
             'response': 'Response with custom temperature',
             'prompt_eval_count': 20,

@@ -360,13 +360,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
             options={'temperature': 0.8} # Should use runtime override
         )

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test model parameter override functionality"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_response = {
             'response': 'Response with custom model',
             'prompt_eval_count': 18,

@@ -408,13 +408,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
             options={'temperature': 0.1} # Should use processor default
         )

-    @patch('trustgraph.model.text_completion.ollama.llm.Client')
+    @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
     async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
         """Test overriding both model and temperature parameters simultaneously"""
         # Arrange
-        mock_client = MagicMock()
+        mock_client = AsyncMock()
         mock_response = {
             'response': 'Response with both overrides',
             'prompt_eval_count': 22,
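The LLM-side tests above follow the same pattern: the client double becomes an AsyncMock so that generate() can be awaited, and an exception configured through side_effect still surfaces at the await point. A minimal illustration (a sketch, not repository code):

import asyncio
from unittest.mock import AsyncMock

import pytest

async def check_error_propagation():
    llm = AsyncMock()
    # A side_effect on an AsyncMock is raised when the returned coroutine is awaited
    llm.generate.side_effect = Exception("Connection error")

    with pytest.raises(Exception, match="Connection error"):
        await llm.generate('llama2', "System prompt\n\nUser prompt",
                           options={'temperature': 0.0})

asyncio.run(check_error_propagation())

The changes to the processors themselves follow, starting with the embeddings service (trustgraph.embeddings.ollama.processor).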
@@ -5,7 +5,7 @@ Input is text, output is embeddings vector.
 """
 from ... base import EmbeddingsService

-from ollama import Client
+from ollama import AsyncClient
 import os
 import logging

@@ -30,24 +30,24 @@ class Processor(EmbeddingsService):
             }
         )

-        self.client = Client(host=ollama)
+        self.client = AsyncClient(host=ollama)
         self.default_model = model
         self._checked_models = set()

-    def _ensure_model(self, model_name):
+    async def _ensure_model(self, model_name):
         """Check if model exists locally, pull it if not."""
         if model_name in self._checked_models:
             return

         try:
-            self.client.show(model_name)
+            await self.client.show(model_name)
             self._checked_models.add(model_name)
         except Exception as e:
             status_code = getattr(e, 'status_code', None)
             if status_code == 404 or "not found" in str(e).lower():
                 logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
                 try:
-                    self.client.pull(model_name)
+                    await self.client.pull(model_name)
                     self._checked_models.add(model_name)
                     logger.info(f"Successfully pulled Ollama model '{model_name}'.")
                 except Exception as pull_e:

@@ -63,10 +63,10 @@ class Processor(EmbeddingsService):
         use_model = model or self.default_model

         # Ensure the model exists/is pulled
-        self._ensure_model(use_model)
+        await self._ensure_model(use_model)

         # Ollama handles batch input efficiently
-        embeds = self.client.embed(
+        embeds = await self.client.embed(
             model = use_model,
             input = texts
         )
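In the embeddings processor above, _ensure_model() becomes a coroutine and show(), pull() and embed() are awaited, so the model check and the embedding call no longer block the event loop. A standalone sketch of the resulting flow, simplified relative to the real Processor (no service base class, minimal error handling); it assumes only the AsyncClient calls visible in the hunks:

import asyncio
from ollama import AsyncClient

class EmbeddingsSketch:
    """Illustrative stand-in for the async embeddings path."""

    def __init__(self, host="http://localhost:11434", model="test-model"):
        self.client = AsyncClient(host=host)
        self.default_model = model
        self._checked_models = set()

    async def _ensure_model(self, model_name):
        # show() raises if the model is absent; pull() fetches it; both are awaited
        if model_name in self._checked_models:
            return
        try:
            await self.client.show(model_name)
        except Exception:
            await self.client.pull(model_name)
        self._checked_models.add(model_name)

    async def on_embeddings(self, texts, model=None):
        use_model = model or self.default_model
        await self._ensure_model(use_model)
        embeds = await self.client.embed(model=use_model, input=texts)
        return embeds.embeddings

# asyncio.run(EmbeddingsSketch().on_embeddings(["some text"]))

The remaining hunks make the same conversion in the text-completion service (trustgraph.model.text_completion.ollama.llm).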
@@ -4,7 +4,7 @@ Simple LLM service, performs text prompt completion using an Ollama service.
 Input is prompt, output is response.
 """

-from ollama import Client
+from ollama import AsyncClient
 import os
 import logging

@@ -38,23 +38,23 @@ class Processor(LlmService):

         self.default_model = model
         self.temperature = temperature
-        self.llm = Client(host=ollama)
+        self.llm = AsyncClient(host=ollama)
         self._checked_models = set()

-    def _ensure_model(self, model_name):
+    async def _ensure_model(self, model_name):
         """Check if model exists locally, pull it if not."""
         if model_name in self._checked_models:
             return

         try:
-            self.llm.show(model_name)
+            await self.llm.show(model_name)
             self._checked_models.add(model_name)
         except Exception as e:
             status_code = getattr(e, 'status_code', None)
             if status_code == 404 or "not found" in str(e).lower():
                 logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
                 try:
-                    self.llm.pull(model_name)
+                    await self.llm.pull(model_name)
                     self._checked_models.add(model_name)
                     logger.info(f"Successfully pulled Ollama model '{model_name}'.")
                 except Exception as pull_e:

@@ -68,7 +68,7 @@ class Processor(LlmService):
         model_name = model or self.default_model

         # Ensure the model exists/is pulled
-        self._ensure_model(model_name)
+        await self._ensure_model(model_name)
         # Use provided temperature or fall back to default
         effective_temperature = temperature if temperature is not None else self.temperature

@@ -79,7 +79,7 @@ class Processor(LlmService):

         try:

-            response = self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
+            response = await self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})

             response_text = response['response']
             logger.debug("Sending response...")

@@ -113,7 +113,7 @@ class Processor(LlmService):
         model_name = model or self.default_model

         # Ensure the model exists/is pulled
-        self._ensure_model(model_name)
+        await self._ensure_model(model_name)

         effective_temperature = temperature if temperature is not None else self.temperature

@@ -123,7 +123,7 @@ class Processor(LlmService):
             prompt = system + "\n\n" + prompt

         try:
-            stream = self.llm.generate(
+            stream = await self.llm.generate(
                 model_name,
                 prompt,
                 options={'temperature': effective_temperature},

@@ -133,7 +133,7 @@ class Processor(LlmService):
             total_input_tokens = 0
             total_output_tokens = 0

-            for chunk in stream:
+            async for chunk in stream:
                 if 'response' in chunk and chunk['response']:
                     yield LlmChunk(
                         text=chunk['response'],
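The final hunk consumes the streamed completion with async for instead of a plain for loop. A minimal sketch of that streaming pattern (illustrative only; it assumes AsyncClient.generate() returns an async iterator of chunks when called with stream=True, an argument that sits in the part of the call not shown in the hunk):

import asyncio
from ollama import AsyncClient

async def stream_completion(prompt, model="llama2", temperature=0.0):
    llm = AsyncClient(host="http://localhost:11434")
    # With streaming enabled, the awaited generate() yields chunks asynchronously
    stream = await llm.generate(model, prompt,
                                options={'temperature': temperature},
                                stream=True)
    async for chunk in stream:
        if 'response' in chunk and chunk['response']:
            print(chunk['response'], end="", flush=True)

asyncio.run(stream_completion("Say hello"))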