2025-11-10 20:38:01 +00:00
|
|
|
"""
Unit tests for FastEmbed dynamic model loading

Tests the model caching and dynamic loading functionality for the FastEmbed
embeddings service.
"""
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
from unittest.mock import Mock, patch, AsyncMock
|
|
|
|
|
from unittest import IsolatedAsyncioTestCase
|
|
|
|
|
from trustgraph.embeddings.fastembed.processor import Processor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
    """Test FastEmbed dynamic model loading and caching.

    Every test patches ``TextEmbedding`` (as referenced from the processor
    module) together with ``AsyncProcessor.__init__`` and
    ``EmbeddingsService.__init__``, so ``Processor`` is built against mocks
    only. ``patch`` decorators are applied bottom-up, which is why the mock
    arguments arrive in the order (embeddings init, async init, TextEmbedding).
    """

    # Vector reported by the ``tolist()`` of every mocked embedding.
    EXPECTED_VECTOR = [0.1, 0.2, 0.3, 0.4, 0.5]

    def _wire_mocks(self, mock_embeddings_init, mock_async_init,
                    mock_text_embedding_class):
        """Configure the patched collaborators shared by all tests.

        Args:
            mock_embeddings_init: patched ``EmbeddingsService.__init__``.
            mock_async_init: patched ``AsyncProcessor.__init__``.
            mock_text_embedding_class: patched ``TextEmbedding`` class.

        Returns:
            The mock instance that calling ``TextEmbedding(...)`` yields; its
            ``embed`` returns a single item whose ``tolist()`` gives a copy of
            ``EXPECTED_VECTOR``.
        """
        mock_fastembed_instance = Mock()
        mock_fastembed_instance.embed.return_value = [
            Mock(tolist=lambda: list(self.EXPECTED_VECTOR))
        ]
        mock_text_embedding_class.return_value = mock_fastembed_instance
        # A patched __init__ must return None, as a real __init__ does.
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        return mock_fastembed_instance

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_default_model_loaded_on_init(
        self, mock_embeddings_init, mock_async_init, mock_text_embedding_class
    ):
        """Test that default model is loaded during initialization"""
        # Arrange
        self._wire_mocks(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )

        base_params = {
            "id": "test-embeddings",
            "concurrency": 1,
            "model": "test-model",
            "taskgroup": AsyncMock(),
        }

        # Act
        processor = Processor(**base_params)

        # Assert
        mock_text_embedding_class.assert_called_once_with(model_name="test-model")
        assert processor.default_model == "test-model"
        assert processor.cached_model_name == "test-model"
        assert processor.embeddings is not None

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_model_caching_avoids_reload(
        self, mock_embeddings_init, mock_async_init, mock_text_embedding_class
    ):
        """Test that using the same model doesn't reload it"""
        # Arrange
        self._wire_mocks(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )

        processor = Processor(
            id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()
        )
        # Forget the load performed by the constructor so only later loads count.
        mock_text_embedding_class.reset_mock()

        # Act - use same model multiple times
        processor._load_model("test-model")
        processor._load_model("test-model")
        processor._load_model("test-model")

        # Assert - model should not be reloaded
        mock_text_embedding_class.assert_not_called()
        assert processor.cached_model_name == "test-model"

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_model_reload_on_name_change(
        self, mock_embeddings_init, mock_async_init, mock_text_embedding_class
    ):
        """Test that changing model name triggers reload"""
        # Arrange
        self._wire_mocks(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )

        processor = Processor(
            id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()
        )
        mock_text_embedding_class.reset_mock()

        # Act - switch to different model
        processor._load_model("different-model")

        # Assert
        mock_text_embedding_class.assert_called_once_with(
            model_name="different-model"
        )
        assert processor.cached_model_name == "different-model"

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_default_model(
        self, mock_embeddings_init, mock_async_init, mock_text_embedding_class
    ):
        """Test that on_embeddings uses default model when no model specified"""
        # Arrange
        mock_fastembed_instance = self._wire_mocks(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )

        processor = Processor(
            id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()
        )
        mock_text_embedding_class.reset_mock()

        # Act
        result = await processor.on_embeddings(["test text"])

        # Assert
        mock_fastembed_instance.embed.assert_called_once_with(["test text"])
        assert processor.cached_model_name == "test-model"  # Still using default
        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_specified_model(
        self, mock_embeddings_init, mock_async_init, mock_text_embedding_class
    ):
        """Test that on_embeddings uses specified model when provided"""
        # Arrange
        mock_fastembed_instance = self._wire_mocks(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )

        processor = Processor(
            id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()
        )
        mock_text_embedding_class.reset_mock()

        # Act
        result = await processor.on_embeddings(["test text"], model="custom-model")

        # Assert
        mock_text_embedding_class.assert_called_once_with(model_name="custom-model")
        assert processor.cached_model_name == "custom-model"
        mock_fastembed_instance.embed.assert_called_once_with(["test text"])
        # The freshly loaded model is the same mock, so the vectors match the
        # default-model case.
        assert result == [self.EXPECTED_VECTOR]

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_multiple_model_switches(
        self, mock_embeddings_init, mock_async_init, mock_text_embedding_class
    ):
        """Test switching between multiple models"""
        # Arrange
        self._wire_mocks(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )

        processor = Processor(
            id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()
        )
        # Baseline includes whatever loads the constructor performed.
        initial_call_count = mock_text_embedding_class.call_count

        # Act - switch between models
        await processor.on_embeddings(["text1"], model="model-a")
        call_count_after_a = mock_text_embedding_class.call_count

        await processor.on_embeddings(["text2"], model="model-a")  # Same, no reload
        call_count_after_a_repeat = mock_text_embedding_class.call_count

        await processor.on_embeddings(["text3"], model="model-b")  # Different, reload
        call_count_after_b = mock_text_embedding_class.call_count

        await processor.on_embeddings(["text4"], model="model-a")  # Back to A, reload
        call_count_after_a_again = mock_text_embedding_class.call_count

        # Assert
        assert call_count_after_a == initial_call_count + 1  # First load
        assert call_count_after_a_repeat == initial_call_count + 1  # No reload
        assert call_count_after_b == initial_call_count + 2  # Reload for model-b
        assert call_count_after_a_again == initial_call_count + 3  # Reload back to model-a

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_none_model_uses_default(
        self, mock_embeddings_init, mock_async_init, mock_text_embedding_class
    ):
        """Test that None model parameter falls back to default"""
        # Arrange
        self._wire_mocks(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )

        processor = Processor(
            id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()
        )
        initial_count = mock_text_embedding_class.call_count

        # Act
        result = await processor.on_embeddings(["test text"], model=None)

        # Assert
        # No reload, using cached default
        assert mock_text_embedding_class.call_count == initial_count
        assert processor.cached_model_name == "test-model"
        assert result == [self.EXPECTED_VECTOR]

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_initialization_without_model_uses_default(
        self, mock_embeddings_init, mock_async_init, mock_text_embedding_class
    ):
        """Test initialization without model parameter uses module default"""
        # Arrange
        self._wire_mocks(
            mock_embeddings_init, mock_async_init, mock_text_embedding_class
        )

        # Act
        processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())

        # Assert
        # Should use default_model from module
        expected_default = "sentence-transformers/all-MiniLM-L6-v2"
        mock_text_embedding_class.assert_called_once_with(model_name=expected_default)
        assert processor.default_model == expected_default
        assert processor.cached_model_name == expected_default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # pytest.main() returns the exit code of the test run; propagate it so
    # that running this file directly signals failures to the calling shell
    # instead of discarding the result.
    raise SystemExit(pytest.main([__file__]))
|