trustgraph/tests/unit/test_embeddings/test_fastembed_dynamic_model.py
cybermaggedon d9d4c91363
Dynamic embeddings model (#556)
* Dynamic embeddings model selection

* Added tests

* HF embeddings tests are skipped; tests don't currently run with that package
2025-11-10 20:38:01 +00:00

216 lines
10 KiB
Python

"""
Unit tests for FastEmbed dynamic model loading
Tests the model caching and dynamic loading functionality for FastEmbed
embeddings service.
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from unittest import IsolatedAsyncioTestCase
from trustgraph.embeddings.fastembed.processor import Processor
class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
    """Test FastEmbed dynamic model loading and caching.

    Every test patches the same three collaborators so that constructing
    a ``Processor`` neither loads a real embedding model nor runs the
    base-class initialisers:

    * ``TextEmbedding`` as referenced from the fastembed processor module
    * ``AsyncProcessor.__init__``
    * ``EmbeddingsService.__init__``

    ``patch`` decorators are applied bottom-up, so the mock arguments
    arrive in the order (embeddings init, async init, TextEmbedding
    class).
    """

    # Vector produced by ``tolist()`` on the single stubbed embedding.
    _VECTOR = (0.1, 0.2, 0.3, 0.4, 0.5)

    def _stub_collaborators(self, mock_embeddings_init, mock_async_init,
                            mock_text_embedding_class):
        """Configure the patched collaborators shared by every test.

        Both patched ``__init__`` methods are made to return ``None`` (as
        a real ``__init__`` must), and the patched ``TextEmbedding``
        class is made to return a single mock instance whose ``embed``
        call yields one embedding convertible to ``_VECTOR``.

        Returns:
            The mock instance handed out by the patched ``TextEmbedding``
            class, so tests can assert on its ``embed`` calls.
        """
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None

        instance = Mock()
        instance.embed.return_value = [
            Mock(tolist=lambda: list(self._VECTOR)),
        ]
        mock_text_embedding_class.return_value = instance
        return instance

    def _make_processor(self, **overrides):
        """Build a ``Processor`` with the standard test parameters.

        Defaults to ``model="test-model"``; keyword ``overrides`` replace
        or extend the default parameters.
        """
        params = {
            "id": "test",
            "concurrency": 1,
            "model": "test-model",
            "taskgroup": AsyncMock(),
        }
        params.update(overrides)
        return Processor(**params)

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_default_model_loaded_on_init(
            self, mock_embeddings_init, mock_async_init,
            mock_text_embedding_class):
        """Test that default model is loaded during initialization"""
        # Arrange
        self._stub_collaborators(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )

        # Act
        processor = self._make_processor(id="test-embeddings")

        # Assert
        mock_text_embedding_class.assert_called_once_with(
            model_name="test-model"
        )
        assert processor.default_model == "test-model"
        assert processor.cached_model_name == "test-model"
        assert processor.embeddings is not None

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_model_caching_avoids_reload(
            self, mock_embeddings_init, mock_async_init,
            mock_text_embedding_class):
        """Test that using the same model doesn't reload it"""
        # Arrange
        self._stub_collaborators(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )
        processor = self._make_processor()
        # Forget the construction-time load so only later loads count.
        mock_text_embedding_class.reset_mock()

        # Act - use same model multiple times
        for _ in range(3):
            processor._load_model("test-model")

        # Assert - model should not be reloaded
        mock_text_embedding_class.assert_not_called()
        assert processor.cached_model_name == "test-model"

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_model_reload_on_name_change(
            self, mock_embeddings_init, mock_async_init,
            mock_text_embedding_class):
        """Test that changing model name triggers reload"""
        # Arrange
        self._stub_collaborators(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )
        processor = self._make_processor()
        mock_text_embedding_class.reset_mock()

        # Act - switch to different model
        processor._load_model("different-model")

        # Assert
        mock_text_embedding_class.assert_called_once_with(
            model_name="different-model"
        )
        assert processor.cached_model_name == "different-model"

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_default_model(
            self, mock_embeddings_init, mock_async_init,
            mock_text_embedding_class):
        """Test that on_embeddings uses default model when no model specified"""
        # Arrange
        mock_fastembed_instance = self._stub_collaborators(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )
        processor = self._make_processor()
        mock_text_embedding_class.reset_mock()

        # Act
        result = await processor.on_embeddings("test text")

        # Assert
        mock_fastembed_instance.embed.assert_called_once_with(["test text"])
        assert processor.cached_model_name == "test-model"  # Still default
        assert result == [list(self._VECTOR)]

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_specified_model(
            self, mock_embeddings_init, mock_async_init,
            mock_text_embedding_class):
        """Test that on_embeddings uses specified model when provided"""
        # Arrange
        mock_fastembed_instance = self._stub_collaborators(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )
        processor = self._make_processor()
        mock_text_embedding_class.reset_mock()

        # Act - the return value is not inspected by this test
        await processor.on_embeddings("test text", model="custom-model")

        # Assert
        mock_text_embedding_class.assert_called_once_with(
            model_name="custom-model"
        )
        assert processor.cached_model_name == "custom-model"
        mock_fastembed_instance.embed.assert_called_once_with(["test text"])

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_multiple_model_switches(
            self, mock_embeddings_init, mock_async_init,
            mock_text_embedding_class):
        """Test switching between multiple models"""
        # Arrange
        self._stub_collaborators(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )
        processor = self._make_processor()
        initial_call_count = mock_text_embedding_class.call_count

        # Each step: (text, requested model, expected number of model
        # loads since construction once the request has completed).
        steps = (
            ("text1", "model-a", 1),  # First load
            ("text2", "model-a", 1),  # Same model, no reload
            ("text3", "model-b", 2),  # Different model, reload
            ("text4", "model-a", 3),  # Back to model-a, reload again
        )

        # Act - switch between models, recording loads after each request
        observed_loads = []
        for text, model, _ in steps:
            await processor.on_embeddings(text, model=model)
            observed_loads.append(
                mock_text_embedding_class.call_count - initial_call_count
            )

        # Assert
        assert observed_loads == [expected for _, _, expected in steps]

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_none_model_uses_default(
            self, mock_embeddings_init, mock_async_init,
            mock_text_embedding_class):
        """Test that None model parameter falls back to default"""
        # Arrange
        self._stub_collaborators(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )
        processor = self._make_processor()
        initial_count = mock_text_embedding_class.call_count

        # Act - the return value is not inspected by this test
        await processor.on_embeddings("test text", model=None)

        # Assert - no reload, using cached default
        assert mock_text_embedding_class.call_count == initial_count
        assert processor.cached_model_name == "test-model"

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_initialization_without_model_uses_default(
            self, mock_embeddings_init, mock_async_init,
            mock_text_embedding_class):
        """Test initialization without model parameter uses module default"""
        # Arrange
        self._stub_collaborators(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )

        # Act - deliberately built without a ``model`` argument, so the
        # shared ``_make_processor`` defaults are not used here.
        processor = Processor(
            id="test-embeddings", concurrency=1, taskgroup=AsyncMock()
        )

        # Assert - should use default_model from module
        expected_default = "sentence-transformers/all-MiniLM-L6-v2"
        mock_text_embedding_class.assert_called_once_with(
            model_name=expected_default
        )
        assert processor.default_model == expected_default
        assert processor.cached_model_name == expected_default
if __name__ == '__main__':
    # pytest.main() returns the run's exit code instead of exiting the
    # interpreter; raise SystemExit with it so that running this file
    # directly reports test failures through the process exit status.
    raise SystemExit(pytest.main([__file__]))