mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 00:46:22 +02:00
Dynamic embeddings model (#556)
* Dynamic embeddings model selection * Added tests * HF embeddings tests are skipped; the tests don't currently run with that package
This commit is contained in:
parent
6129bb68c1
commit
d9d4c91363
8 changed files with 816 additions and 11 deletions
167
tests/unit/test_embeddings/test_ollama_dynamic_model.py
Normal file
167
tests/unit/test_embeddings/test_ollama_dynamic_model.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
"""
|
||||
Unit tests for Ollama dynamic model loading
|
||||
|
||||
Tests the dynamic model selection functionality for Ollama embeddings service.
|
||||
Since Ollama is server-side, no model caching is needed on the client side.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from trustgraph.embeddings.ollama.processor import Processor
|
||||
|
||||
|
||||
class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
    """Tests for per-request model selection in the Ollama embeddings Processor.

    Every test applies the same three patches:

    * ``trustgraph.embeddings.ollama.processor.Client`` is replaced with a
      mock, so ``embed`` calls are recorded instead of being sent anywhere.
    * ``AsyncProcessor.__init__`` and ``EmbeddingsService.__init__`` are
      replaced with mocks, so constructing a ``Processor`` does not run the
      real base-class initialisers.

    ``patch`` decorators are applied bottom-up, which is why the mock
    arguments of each test arrive in the reverse of the decorator order.
    """

    # Embedding vectors carried by the stubbed ``Client.embed`` response.
    EMBEDDINGS = [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @classmethod
    def _arrange_mocks(cls, mock_embeddings_init, mock_async_init, mock_client_class):
        """Wire up the patched objects and return the stub Ollama client.

        Args:
            mock_embeddings_init: Mock standing in for
                ``EmbeddingsService.__init__``.
            mock_async_init: Mock standing in for ``AsyncProcessor.__init__``.
            mock_client_class: Mock standing in for the ``Client`` class
                referenced by the processor module.

        Returns:
            The ``Mock`` that ``mock_client_class()`` hands out; its
            ``embed`` method returns a response whose ``embeddings``
            attribute equals ``EMBEDDINGS``.
        """
        # A mocked __init__ must return None or instantiation raises TypeError.
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None

        # Hand the stub its own copy so assertions against EMBEDDINGS compare
        # by value with an independent object.
        response = Mock(embeddings=[list(vector) for vector in cls.EMBEDDINGS])

        ollama_client = Mock()
        ollama_client.embed.return_value = response
        mock_client_class.return_value = ollama_client
        return ollama_client

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """The Ollama client is created once with the configured host."""
        # Arrange
        self._arrange_mocks(mock_embeddings_init, mock_async_init, mock_client_class)

        # Act
        processor = Processor(
            id="test",
            concurrency=1,
            model="test-model",
            ollama="http://localhost:11434",
            taskgroup=AsyncMock(),
        )

        # Assert
        mock_client_class.assert_called_once_with(host="http://localhost:11434")
        assert processor.default_model == "test-model"

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """on_embeddings falls back to the constructor model when none is given."""
        # Arrange
        ollama_client = self._arrange_mocks(
            mock_embeddings_init, mock_async_init, mock_client_class
        )
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

        # Act
        result = await processor.on_embeddings("test text")

        # Assert
        ollama_client.embed.assert_called_once_with(
            model="test-model",
            input="test text",
        )
        assert result == self.EMBEDDINGS

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """An explicit ``model`` argument overrides the constructor model."""
        # Arrange
        ollama_client = self._arrange_mocks(
            mock_embeddings_init, mock_async_init, mock_client_class
        )
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

        # Act
        result = await processor.on_embeddings("test text", model="custom-model")

        # Assert
        ollama_client.embed.assert_called_once_with(
            model="custom-model",
            input="test text",
        )
        assert result == self.EMBEDDINGS

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Consecutive requests may each name a different model, or none."""
        # Arrange
        ollama_client = self._arrange_mocks(
            mock_embeddings_init, mock_async_init, mock_client_class
        )
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

        # Act - switch between different models
        await processor.on_embeddings("text1", model="model-a")
        await processor.on_embeddings("text2", model="model-b")
        await processor.on_embeddings("text3", model="model-a")
        await processor.on_embeddings("text4")  # Use default

        # Assert - each text must have been embedded with the model that was
        # requested for it, in request order.
        observed = [
            (recorded.kwargs["model"], recorded.kwargs["input"])
            for recorded in ollama_client.embed.call_args_list
        ]
        assert observed == [
            ("model-a", "text1"),
            ("model-b", "text2"),
            ("model-a", "text3"),
            ("test-model", "text4"),  # Default
        ]

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Passing ``model=None`` explicitly behaves like omitting the argument."""
        # Arrange
        ollama_client = self._arrange_mocks(
            mock_embeddings_init, mock_async_init, mock_client_class
        )
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

        # Act
        result = await processor.on_embeddings("test text", model=None)

        # Assert
        ollama_client.embed.assert_called_once_with(
            model="test-model",
            input="test text",
        )
        assert result == self.EMBEDDINGS

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Omitting ``model`` at construction selects the built-in default."""
        # Arrange
        self._arrange_mocks(mock_embeddings_init, mock_async_init, mock_client_class)

        # Act
        processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())

        # Assert
        # NOTE(review): this literal is expected to mirror the default model
        # declared in the processor module - keep the two in sync.
        expected_default = "mxbai-embed-large"
        assert processor.default_model == expected_default
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # pytest.main() returns the exit status instead of exiting; propagate it
    # so running this file directly reports test failures to the shell.
    raise SystemExit(pytest.main([__file__]))
|
||||
Loading…
Add table
Add a link
Reference in a new issue