From d9d4c91363843b5d1abd45a1b1474f094ce1af9f Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 10 Nov 2025 20:38:01 +0000 Subject: [PATCH] Dynamic embeddings model (#556) * Dynamic embeddings model selection * Added tests * HF embeddings are skipped, tests don't run with that package currently tests --- .../test_embeddings_service_contract.py | 157 +++++++++++++ .../test_fastembed_dynamic_model.py | 216 ++++++++++++++++++ .../test_huggingface_dynamic_model.py | 213 +++++++++++++++++ .../test_ollama_dynamic_model.py | 167 ++++++++++++++ .../trustgraph/base/embeddings_service.py | 12 +- .../trustgraph/embeddings/hf/hf.py | 27 ++- .../embeddings/fastembed/processor.py | 27 ++- .../trustgraph/embeddings/ollama/processor.py | 8 +- 8 files changed, 816 insertions(+), 11 deletions(-) create mode 100644 tests/unit/test_embeddings/test_embeddings_service_contract.py create mode 100644 tests/unit/test_embeddings/test_fastembed_dynamic_model.py create mode 100644 tests/unit/test_embeddings/test_huggingface_dynamic_model.py create mode 100644 tests/unit/test_embeddings/test_ollama_dynamic_model.py diff --git a/tests/unit/test_embeddings/test_embeddings_service_contract.py b/tests/unit/test_embeddings/test_embeddings_service_contract.py new file mode 100644 index 00000000..e53faf81 --- /dev/null +++ b/tests/unit/test_embeddings/test_embeddings_service_contract.py @@ -0,0 +1,157 @@ +""" +Contract tests for EmbeddingsService base class + +Tests the contract between the EmbeddingsService base class and its +implementations, ensuring proper integration of the model parameter handling. 
+""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch +from unittest import IsolatedAsyncioTestCase +from trustgraph.base import EmbeddingsService +from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error + + +class ConcreteEmbeddingsService(EmbeddingsService): + """Concrete implementation for testing the abstract base class""" + + def __init__(self, **params): + self.on_embeddings_calls = [] + self.default_model = params.get("model", "default-test-model") + # Don't call super().__init__ to avoid taskgroup requirements in tests + # We're only testing the on_embeddings interface + + async def on_embeddings(self, text, model=None): + """Implementation that tracks calls""" + self.on_embeddings_calls.append({ + "text": text, + "model": model + }) + # Return a simple embedding + return [[0.1, 0.2, 0.3]] + + +class TestEmbeddingsServiceModelParameterContract(IsolatedAsyncioTestCase): + """Test the model parameter contract in embeddings implementations""" + + async def test_on_embeddings_accepts_model_parameter(self): + """Test that on_embeddings method accepts optional model parameter""" + # Arrange + service = ConcreteEmbeddingsService(model="default-model") + + # Act + result1 = await service.on_embeddings("test text") + result2 = await service.on_embeddings("test text", model="custom-model") + result3 = await service.on_embeddings("test text", model=None) + + # Assert + assert len(service.on_embeddings_calls) == 3 + assert service.on_embeddings_calls[0]["model"] is None # No model specified + assert service.on_embeddings_calls[1]["model"] == "custom-model" + assert service.on_embeddings_calls[2]["model"] is None + + async def test_implementation_tracks_model_changes(self): + """Test that implementations properly track which model is requested""" + # Arrange + service = ConcreteEmbeddingsService(model="default-model") + + # Act - multiple requests with different models + await service.on_embeddings("text1", model="model-a") + await 
service.on_embeddings("text2", model="model-b") + await service.on_embeddings("text3") # Use default (None passed) + await service.on_embeddings("text4", model="model-a") + + # Assert + assert len(service.on_embeddings_calls) == 4 + assert service.on_embeddings_calls[0]["model"] == "model-a" + assert service.on_embeddings_calls[1]["model"] == "model-b" + assert service.on_embeddings_calls[2]["model"] is None + assert service.on_embeddings_calls[3]["model"] == "model-a" + + async def test_model_parameter_with_various_text_inputs(self): + """Test model parameter works with different text inputs""" + # Arrange + service = ConcreteEmbeddingsService(model="default-model") + + test_cases = [ + ("Simple text", "model-1"), + ("", "model-2"), + ("Unicode: δΈ–η•Œ 🌍", "model-3"), + ("Very " * 100 + "long text", None), + ] + + # Act + for text, model in test_cases: + await service.on_embeddings(text, model=model) + + # Assert + assert len(service.on_embeddings_calls) == len(test_cases) + for i, (text, model) in enumerate(test_cases): + assert service.on_embeddings_calls[i]["text"] == text + assert service.on_embeddings_calls[i]["model"] == model + + async def test_embeddings_return_format(self): + """Test that embeddings are returned in correct format""" + # Arrange + service = ConcreteEmbeddingsService(model="default-model") + + # Act + result = await service.on_embeddings("test text", model="test-model") + + # Assert + assert isinstance(result, list) + assert len(result) > 0 + assert isinstance(result[0], list) + assert all(isinstance(x, float) for x in result[0]) + + +class TestEmbeddingsResponseSchema: + """Test the EmbeddingsResponse schema contract""" + + def test_success_response(self): + """Test creating success response""" + # Act + response = EmbeddingsResponse( + error=None, + vectors=[[0.1, 0.2, 0.3]] + ) + + # Assert + assert response.error is None + assert response.vectors == [[0.1, 0.2, 0.3]] + + def test_error_response(self): + """Test creating error 
response""" + # Act + error = Error(type="test-error", message="Test message") + response = EmbeddingsResponse( + error=error, + vectors=None + ) + + # Assert + assert response.error is not None + assert response.error.type == "test-error" + assert response.error.message == "Test message" + assert response.vectors is None + + def test_response_with_multiple_vectors(self): + """Test response with multiple embedding vectors""" + # Act + vectors = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9] + ] + response = EmbeddingsResponse( + error=None, + vectors=vectors + ) + + # Assert + assert len(response.vectors) == 3 + assert response.vectors[0] == [0.1, 0.2, 0.3] + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_embeddings/test_fastembed_dynamic_model.py b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py new file mode 100644 index 00000000..1c1fb883 --- /dev/null +++ b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py @@ -0,0 +1,216 @@ +""" +Unit tests for FastEmbed dynamic model loading + +Tests the model caching and dynamic loading functionality for FastEmbed +embeddings service. 
+""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from unittest import IsolatedAsyncioTestCase +from trustgraph.embeddings.fastembed.processor import Processor + + +class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): + """Test FastEmbed dynamic model loading and caching""" + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that default model is loaded during initialization""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + base_params = { + "id": "test-embeddings", + "concurrency": 1, + "model": "test-model", + "taskgroup": AsyncMock() + } + + # Act + processor = Processor(**base_params) + + # Assert + mock_text_embedding_class.assert_called_once_with(model_name="test-model") + assert processor.default_model == "test-model" + assert processor.cached_model_name == "test-model" + assert processor.embeddings is not None + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that using the same model doesn't reload it""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + mock_text_embedding_class.return_value = mock_fastembed_instance 
+ mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_text_embedding_class.reset_mock() + + # Act - use same model multiple times + processor._load_model("test-model") + processor._load_model("test-model") + processor._load_model("test-model") + + # Assert - model should not be reloaded + mock_text_embedding_class.assert_not_called() + assert processor.cached_model_name == "test-model" + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that changing model name triggers reload""" + # Arrange + mock_fastembed_instance = Mock() + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_text_embedding_class.reset_mock() + + # Act - switch to different model + processor._load_model("different-model") + + # Assert + mock_text_embedding_class.assert_called_once_with(model_name="different-model") + assert processor.cached_model_name == "different-model" + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that on_embeddings uses default model when no model specified""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 
0.2, 0.3, 0.4, 0.5])] + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_text_embedding_class.reset_mock() + + # Act + result = await processor.on_embeddings("test text") + + # Assert + mock_fastembed_instance.embed.assert_called_once_with(["test text"]) + assert processor.cached_model_name == "test-model" # Still using default + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that on_embeddings uses specified model when provided""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_text_embedding_class.reset_mock() + + # Act + result = await processor.on_embeddings("test text", model="custom-model") + + # Assert + mock_text_embedding_class.assert_called_once_with(model_name="custom-model") + assert processor.cached_model_name == "custom-model" + mock_fastembed_instance.embed.assert_called_once_with(["test text"]) + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_multiple_model_switches(self, mock_embeddings_init, 
mock_async_init, mock_text_embedding_class): + """Test switching between multiple models""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + initial_call_count = mock_text_embedding_class.call_count + + # Act - switch between models + await processor.on_embeddings("text1", model="model-a") + call_count_after_a = mock_text_embedding_class.call_count + + await processor.on_embeddings("text2", model="model-a") # Same, no reload + call_count_after_a_repeat = mock_text_embedding_class.call_count + + await processor.on_embeddings("text3", model="model-b") # Different, reload + call_count_after_b = mock_text_embedding_class.call_count + + await processor.on_embeddings("text4", model="model-a") # Back to A, reload + call_count_after_a_again = mock_text_embedding_class.call_count + + # Assert + assert call_count_after_a == initial_call_count + 1 # First load + assert call_count_after_a_repeat == initial_call_count + 1 # No reload + assert call_count_after_b == initial_call_count + 2 # Reload for model-b + assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that None model parameter falls back to default""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + 
mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + initial_count = mock_text_embedding_class.call_count + + # Act + result = await processor.on_embeddings("test text", model=None) + + # Assert + # No reload, using cached default + assert mock_text_embedding_class.call_count == initial_count + assert processor.cached_model_name == "test-model" + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test initialization without model parameter uses module default""" + # Arrange + mock_fastembed_instance = Mock() + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + # Act + processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock()) + + # Assert + # Should use default_model from module + expected_default = "sentence-transformers/all-MiniLM-L6-v2" + mock_text_embedding_class.assert_called_once_with(model_name=expected_default) + assert processor.default_model == expected_default + assert processor.cached_model_name == expected_default + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_embeddings/test_huggingface_dynamic_model.py b/tests/unit/test_embeddings/test_huggingface_dynamic_model.py new file mode 100644 index 00000000..aef6fc92 --- /dev/null +++ b/tests/unit/test_embeddings/test_huggingface_dynamic_model.py @@ -0,0 +1,213 @@ +""" +Unit tests for HuggingFace dynamic model loading + +Tests the model caching and dynamic loading 
functionality for HuggingFace +embeddings service using LangChain's HuggingFaceEmbeddings. +""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from unittest import IsolatedAsyncioTestCase + +# Skip all tests in this module if trustgraph.embeddings.hf is not installed +pytest.importorskip("trustgraph.embeddings.hf") + +from trustgraph.embeddings.hf.hf import Processor + + +class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase): + """Test HuggingFace dynamic model loading and caching""" + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that default model is loaded during initialization""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + # Act + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + + # Assert + mock_hf_class.assert_called_once_with(model_name="test-model") + assert processor.default_model == "test-model" + assert processor.cached_model_name == "test-model" + assert processor.embeddings is not None + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that using the same model doesn't reload it""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = 
mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_hf_class.reset_mock() + + # Act - use same model multiple times + processor._load_model("test-model") + processor._load_model("test-model") + processor._load_model("test-model") + + # Assert - model should not be reloaded + mock_hf_class.assert_not_called() + assert processor.cached_model_name == "test-model" + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that changing model name triggers reload""" + # Arrange + mock_hf_instance = Mock() + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_hf_class.reset_mock() + + # Act - switch to different model + processor._load_model("different-model") + + # Assert + mock_hf_class.assert_called_once_with(model_name="different-model") + assert processor.cached_model_name == "different-model" + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that on_embeddings uses default model when no model specified""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + 
mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_hf_class.reset_mock() + + # Act + result = await processor.on_embeddings("test text") + + # Assert + mock_hf_instance.embed_documents.assert_called_once_with(["test text"]) + assert processor.cached_model_name == "test-model" # Still using default + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that on_embeddings uses specified model when provided""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_hf_class.reset_mock() + + # Act + result = await processor.on_embeddings("test text", model="custom-model") + + # Assert + mock_hf_class.assert_called_once_with(model_name="custom-model") + assert processor.cached_model_name == "custom-model" + mock_hf_instance.embed_documents.assert_called_once_with(["test text"]) + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test switching between multiple models""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = 
mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + initial_call_count = mock_hf_class.call_count + + # Act - switch between models + await processor.on_embeddings("text1", model="model-a") + call_count_after_a = mock_hf_class.call_count + + await processor.on_embeddings("text2", model="model-a") # Same, no reload + call_count_after_a_repeat = mock_hf_class.call_count + + await processor.on_embeddings("text3", model="model-b") # Different, reload + call_count_after_b = mock_hf_class.call_count + + await processor.on_embeddings("text4", model="model-a") # Back to A, reload + call_count_after_a_again = mock_hf_class.call_count + + # Assert + assert call_count_after_a == initial_call_count + 1 # First load + assert call_count_after_a_repeat == initial_call_count + 1 # No reload + assert call_count_after_b == initial_call_count + 2 # Reload for model-b + assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that None model parameter falls back to default""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + initial_count = mock_hf_class.call_count + + # Act + result = await processor.on_embeddings("test text", model=None) + + # Assert + # No reload, using cached default + assert 
mock_hf_class.call_count == initial_count + assert processor.cached_model_name == "test-model" + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test initialization without model parameter uses module default""" + # Arrange + mock_hf_instance = Mock() + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + # Act + processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock()) + + # Assert + # Should use default_model from module + expected_default = "all-MiniLM-L6-v2" + mock_hf_class.assert_called_once_with(model_name=expected_default) + assert processor.default_model == expected_default + assert processor.cached_model_name == expected_default + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_embeddings/test_ollama_dynamic_model.py b/tests/unit/test_embeddings/test_ollama_dynamic_model.py new file mode 100644 index 00000000..ca0f44bf --- /dev/null +++ b/tests/unit/test_embeddings/test_ollama_dynamic_model.py @@ -0,0 +1,167 @@ +""" +Unit tests for Ollama dynamic model loading + +Tests the dynamic model selection functionality for Ollama embeddings service. +Since Ollama is server-side, no model caching is needed on the client side. 
+""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from unittest import IsolatedAsyncioTestCase +from trustgraph.embeddings.ollama.processor import Processor + + +class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): + """Test Ollama dynamic model selection""" + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test that Ollama client is initialized with correct host""" + # Arrange + mock_ollama_client = Mock() + mock_response = Mock() + mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_ollama_client.embed.return_value = mock_response + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + # Act + processor = Processor(id="test", concurrency=1, model="test-model", + ollama="http://localhost:11434", taskgroup=AsyncMock()) + + # Assert + mock_client_class.assert_called_once_with(host="http://localhost:11434") + assert processor.default_model == "test-model" + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test that on_embeddings uses default model when no model specified""" + # Arrange + mock_ollama_client = Mock() + mock_response = Mock() + mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_ollama_client.embed.return_value = mock_response + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = 
Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + + # Act + result = await processor.on_embeddings("test text") + + # Assert + mock_ollama_client.embed.assert_called_once_with( + model="test-model", + input="test text" + ) + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test that on_embeddings uses specified model when provided""" + # Arrange + mock_ollama_client = Mock() + mock_response = Mock() + mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_ollama_client.embed.return_value = mock_response + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + + # Act + result = await processor.on_embeddings("test text", model="custom-model") + + # Assert + mock_ollama_client.embed.assert_called_once_with( + model="custom-model", + input="test text" + ) + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test switching between multiple models""" + # Arrange + mock_ollama_client = Mock() + mock_response = Mock() + mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_ollama_client.embed.return_value = mock_response + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + 
mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + + # Act - switch between different models + await processor.on_embeddings("text1", model="model-a") + await processor.on_embeddings("text2", model="model-b") + await processor.on_embeddings("text3", model="model-a") + await processor.on_embeddings("text4") # Use default + + # Assert + calls = mock_ollama_client.embed.call_args_list + assert len(calls) == 4 + assert calls[0][1]['model'] == "model-a" + assert calls[1][1]['model'] == "model-b" + assert calls[2][1]['model'] == "model-a" + assert calls[3][1]['model'] == "test-model" # Default + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test that None model parameter falls back to default""" + # Arrange + mock_ollama_client = Mock() + mock_response = Mock() + mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_ollama_client.embed.return_value = mock_response + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + + # Act + result = await processor.on_embeddings("test text", model=None) + + # Assert + mock_ollama_client.embed.assert_called_once_with( + model="test-model", + input="test text" + ) + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test initialization 
without model parameter uses module default""" + # Arrange + mock_ollama_client = Mock() + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + # Act + processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock()) + + # Assert + # Should use default_model from module + expected_default = "mxbai-embed-large" + assert processor.default_model == expected_default + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/trustgraph-base/trustgraph/base/embeddings_service.py b/trustgraph-base/trustgraph/base/embeddings_service.py index 556d32ff..a1442d41 100644 --- a/trustgraph-base/trustgraph/base/embeddings_service.py +++ b/trustgraph-base/trustgraph/base/embeddings_service.py @@ -9,7 +9,7 @@ from prometheus_client import Histogram from .. schema import EmbeddingsRequest, EmbeddingsResponse, Error from .. exceptions import TooManyRequests -from .. base import FlowProcessor, ConsumerSpec, ProducerSpec +from .. 
base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec # Module logger logger = logging.getLogger(__name__) @@ -45,6 +45,12 @@ class EmbeddingsService(FlowProcessor): ) ) + self.register_specification( + ParameterSpec( + name = "model", + ) + ) + async def on_request(self, msg, consumer, flow): try: @@ -57,7 +63,9 @@ logger.debug(f"Handling embeddings request {id}...") - vectors = await self.on_embeddings(request.text) + # Resolve model from the flow's "model" parameter; implementations fall back to their default when unset + model = flow("model") + vectors = await self.on_embeddings(request.text, model=model) await flow("response").send( EmbeddingsResponse( diff --git a/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py b/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py index f1abbfae..8c4d571b 100755 --- a/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py +++ b/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py @@ -26,10 +26,31 @@ class Processor(EmbeddingsService): **params | { "model": model } ) - logger.info(f"Loading HuggingFace embeddings model: {model}") - self.embeddings = HuggingFaceEmbeddings(model_name=model) + self.default_model = model - async def on_embeddings(self, text): + # Cache for currently loaded model + self.cached_model_name = None + self.embeddings = None + + # Load the default model + self._load_model(model) + + def _load_model(self, model_name): + """Load a model, caching it for reuse""" + if self.cached_model_name != model_name: + logger.info(f"Loading HuggingFace embeddings model: {model_name}") + self.embeddings = HuggingFaceEmbeddings(model_name=model_name) + self.cached_model_name = model_name + logger.info(f"HuggingFace model {model_name} loaded successfully") + else: + logger.debug(f"Using cached model: {model_name}") + + async def on_embeddings(self, text, model=None): + + use_model = model or self.default_model + + # Reload model if it has changed + self._load_model(use_model) embeds 
= self.embeddings.embed_documents([text]) logger.debug("Embeddings generation complete") diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py index 0357e4a3..d1ce93ca 100755 --- a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py @@ -27,10 +27,31 @@ class Processor(EmbeddingsService): **params | { "model": model } ) - logger.info("Loading FastEmbed model...") - self.embeddings = TextEmbedding(model_name = model) + self.default_model = model - async def on_embeddings(self, text): + # Cache for currently loaded model + self.cached_model_name = None + self.embeddings = None + + # Load the default model + self._load_model(model) + + def _load_model(self, model_name): + """Load a model, caching it for reuse""" + if self.cached_model_name != model_name: + logger.info(f"Loading FastEmbed model: {model_name}") + self.embeddings = TextEmbedding(model_name=model_name) + self.cached_model_name = model_name + logger.info(f"FastEmbed model {model_name} loaded successfully") + else: + logger.debug(f"Using cached model: {model_name}") + + async def on_embeddings(self, text, model=None): + + use_model = model or self.default_model + + # Reload model if it has changed + self._load_model(use_model) vecs = self.embeddings.embed([text]) diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index 3c0776f9..c951252e 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -28,12 +28,14 @@ class Processor(EmbeddingsService): ) self.client = Client(host=ollama) - self.model = model + self.default_model = model - async def on_embeddings(self, text): + async def on_embeddings(self, text, model=None): + + use_model = model or self.default_model embeds = self.client.embed( - 
model = self.model, + model = use_model, input = text )