mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 08:56:21 +02:00
Dynamic embeddings model (#556)
* Dynamic embeddings model selection * Added tests * HF embeddings are skipped, tests don't run with that package currently tests
This commit is contained in:
parent
6129bb68c1
commit
d9d4c91363
8 changed files with 816 additions and 11 deletions
157
tests/unit/test_embeddings/test_embeddings_service_contract.py
Normal file
157
tests/unit/test_embeddings/test_embeddings_service_contract.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
"""
|
||||
Contract tests for EmbeddingsService base class
|
||||
|
||||
Tests the contract between the EmbeddingsService base class and its
|
||||
implementations, ensuring proper integration of the model parameter handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from trustgraph.base import EmbeddingsService
|
||||
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
|
||||
|
||||
|
||||
class ConcreteEmbeddingsService(EmbeddingsService):
    """Minimal concrete EmbeddingsService used to exercise the abstract base class."""

    def __init__(self, **params):
        # Deliberately skip super().__init__(): the real base class requires a
        # taskgroup, and these tests only exercise the on_embeddings interface.
        self.default_model = params.get("model", "default-test-model")
        # Every on_embeddings invocation is recorded here for later assertions.
        self.on_embeddings_calls = []

    async def on_embeddings(self, text, model=None):
        """Record the call's text/model pair, then return a fixed one-vector embedding."""
        record = {"text": text, "model": model}
        self.on_embeddings_calls.append(record)
        return [[0.1, 0.2, 0.3]]
||||
|
||||
class TestEmbeddingsServiceModelParameterContract(IsolatedAsyncioTestCase):
    """Test the model parameter contract in embeddings implementations"""

    async def test_on_embeddings_accepts_model_parameter(self):
        """Test that on_embeddings method accepts optional model parameter"""
        service = ConcreteEmbeddingsService(model="default-model")

        # Exercise all three calling conventions: positional-only text,
        # explicit model, and explicit model=None.
        await service.on_embeddings("test text")
        await service.on_embeddings("test text", model="custom-model")
        await service.on_embeddings("test text", model=None)

        recorded = [call["model"] for call in service.on_embeddings_calls]
        assert len(service.on_embeddings_calls) == 3
        # No model specified, explicit model, explicit None.
        assert recorded == [None, "custom-model", None]

    async def test_implementation_tracks_model_changes(self):
        """Test that implementations properly track which model is requested"""
        service = ConcreteEmbeddingsService(model="default-model")

        # Interleave two named models with one default (no kwarg) call.
        await service.on_embeddings("text1", model="model-a")
        await service.on_embeddings("text2", model="model-b")
        await service.on_embeddings("text3")  # default path records None
        await service.on_embeddings("text4", model="model-a")

        assert len(service.on_embeddings_calls) == 4
        recorded = [call["model"] for call in service.on_embeddings_calls]
        assert recorded == ["model-a", "model-b", None, "model-a"]

    async def test_model_parameter_with_various_text_inputs(self):
        """Test model parameter works with different text inputs"""
        service = ConcreteEmbeddingsService(model="default-model")

        cases = [
            ("Simple text", "model-1"),
            ("", "model-2"),
            ("Unicode: 世界 🌍", "model-3"),
            ("Very " * 100 + "long text", None),
        ]

        for text, model in cases:
            await service.on_embeddings(text, model=model)

        assert len(service.on_embeddings_calls) == len(cases)
        for call, (text, model) in zip(service.on_embeddings_calls, cases):
            assert call["text"] == text
            assert call["model"] == model

    async def test_embeddings_return_format(self):
        """Test that embeddings are returned in correct format"""
        service = ConcreteEmbeddingsService(model="default-model")

        vectors = await service.on_embeddings("test text", model="test-model")

        # Expected shape: non-empty list of float vectors.
        assert isinstance(vectors, list)
        assert len(vectors) > 0
        first = vectors[0]
        assert isinstance(first, list)
        assert all(isinstance(component, float) for component in first)
|
||||
class TestEmbeddingsResponseSchema:
    """Test the EmbeddingsResponse schema contract"""

    def test_success_response(self):
        """Test creating success response"""
        response = EmbeddingsResponse(error=None, vectors=[[0.1, 0.2, 0.3]])

        assert response.error is None
        assert response.vectors == [[0.1, 0.2, 0.3]]

    def test_error_response(self):
        """Test creating error response"""
        failure = Error(type="test-error", message="Test message")
        response = EmbeddingsResponse(error=failure, vectors=None)

        # An error response carries the Error object and no vectors.
        assert response.error is not None
        assert response.error.type == "test-error"
        assert response.error.message == "Test message"
        assert response.vectors is None

    def test_response_with_multiple_vectors(self):
        """Test response with multiple embedding vectors"""
        vectors = [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9],
        ]
        response = EmbeddingsResponse(error=None, vectors=vectors)

        assert len(response.vectors) == 3
        assert response.vectors[0] == [0.1, 0.2, 0.3]
||||
|
||||
if __name__ == '__main__':
    # pytest.main returns an exit code; propagate it so running this file
    # directly reports failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
||||
216
tests/unit/test_embeddings/test_fastembed_dynamic_model.py
Normal file
216
tests/unit/test_embeddings/test_fastembed_dynamic_model.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
"""
|
||||
Unit tests for FastEmbed dynamic model loading
|
||||
|
||||
Tests the model caching and dynamic loading functionality for FastEmbed
|
||||
embeddings service.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from trustgraph.embeddings.fastembed.processor import Processor
|
||||
|
||||
|
||||
class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
    """Test FastEmbed dynamic model loading and caching"""

    @staticmethod
    def _wire_mocks(mock_embeddings_init, mock_async_init, mock_text_embedding_class):
        """Point the patched constructors at benign fakes; return the fake fastembed model."""
        fake_model = Mock()
        fake_model.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
        mock_text_embedding_class.return_value = fake_model
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        return fake_model

    @staticmethod
    def _make_processor(**overrides):
        """Build a Processor from the standard test parameters, with overrides."""
        params = {"id": "test", "concurrency": 1, "model": "test-model",
                  "taskgroup": AsyncMock()}
        params.update(overrides)
        return Processor(**params)

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
        """Test that default model is loaded during initialization"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_text_embedding_class)

        processor = self._make_processor(id="test-embeddings")

        mock_text_embedding_class.assert_called_once_with(model_name="test-model")
        assert processor.default_model == "test-model"
        assert processor.cached_model_name == "test-model"
        assert processor.embeddings is not None

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
        """Test that using the same model doesn't reload it"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_text_embedding_class)
        processor = self._make_processor()
        mock_text_embedding_class.reset_mock()

        # Repeatedly request the model that is already cached.
        for _ in range(3):
            processor._load_model("test-model")

        # The cached instance must be reused: no new TextEmbedding constructed.
        mock_text_embedding_class.assert_not_called()
        assert processor.cached_model_name == "test-model"

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
        """Test that changing model name triggers reload"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_text_embedding_class)
        processor = self._make_processor()
        mock_text_embedding_class.reset_mock()

        processor._load_model("different-model")

        mock_text_embedding_class.assert_called_once_with(model_name="different-model")
        assert processor.cached_model_name == "different-model"

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
        """Test that on_embeddings uses default model when no model specified"""
        fake_model = self._wire_mocks(mock_embeddings_init, mock_async_init, mock_text_embedding_class)
        processor = self._make_processor()
        mock_text_embedding_class.reset_mock()

        vectors = await processor.on_embeddings("test text")

        fake_model.embed.assert_called_once_with(["test text"])
        assert processor.cached_model_name == "test-model"  # still the default
        assert vectors == [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
        """Test that on_embeddings uses specified model when provided"""
        fake_model = self._wire_mocks(mock_embeddings_init, mock_async_init, mock_text_embedding_class)
        processor = self._make_processor()
        mock_text_embedding_class.reset_mock()

        await processor.on_embeddings("test text", model="custom-model")

        mock_text_embedding_class.assert_called_once_with(model_name="custom-model")
        assert processor.cached_model_name == "custom-model"
        fake_model.embed.assert_called_once_with(["test text"])

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
        """Test switching between multiple models"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_text_embedding_class)
        processor = self._make_processor()
        baseline = mock_text_embedding_class.call_count

        async def loads_after(text, model):
            # Run one embedding request, then report loads since the baseline.
            await processor.on_embeddings(text, model=model)
            return mock_text_embedding_class.call_count - baseline

        assert await loads_after("text1", "model-a") == 1  # first load
        assert await loads_after("text2", "model-a") == 1  # same model, no reload
        assert await loads_after("text3", "model-b") == 2  # reload for model-b
        assert await loads_after("text4", "model-a") == 3  # reload back to model-a

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
        """Test that None model parameter falls back to default"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_text_embedding_class)
        processor = self._make_processor()
        baseline = mock_text_embedding_class.call_count

        await processor.on_embeddings("test text", model=None)

        # None falls back to the cached default: no additional load occurs.
        assert mock_text_embedding_class.call_count == baseline
        assert processor.cached_model_name == "test-model"

    @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
        """Test initialization without model parameter uses module default"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_text_embedding_class)

        processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())

        # With no model argument, the module-level default should be loaded.
        expected_default = "sentence-transformers/all-MiniLM-L6-v2"
        mock_text_embedding_class.assert_called_once_with(model_name=expected_default)
        assert processor.default_model == expected_default
        assert processor.cached_model_name == expected_default
|
||||
if __name__ == '__main__':
    # pytest.main returns an exit code; propagate it so running this file
    # directly reports failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
||||
213
tests/unit/test_embeddings/test_huggingface_dynamic_model.py
Normal file
213
tests/unit/test_embeddings/test_huggingface_dynamic_model.py
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
"""
|
||||
Unit tests for HuggingFace dynamic model loading
|
||||
|
||||
Tests the model caching and dynamic loading functionality for HuggingFace
|
||||
embeddings service using LangChain's HuggingFaceEmbeddings.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Skip all tests in this module if trustgraph.embeddings.hf is not installed
|
||||
pytest.importorskip("trustgraph.embeddings.hf")
|
||||
|
||||
from trustgraph.embeddings.hf.hf import Processor
|
||||
|
||||
|
||||
class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
    """Test HuggingFace dynamic model loading and caching"""

    @staticmethod
    def _wire_mocks(mock_embeddings_init, mock_async_init, mock_hf_class):
        """Point the patched constructors at benign fakes; return the fake HF embeddings."""
        fake_embedder = Mock()
        fake_embedder.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_hf_class.return_value = fake_embedder
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        return fake_embedder

    @staticmethod
    def _make_processor(**overrides):
        """Build a Processor from the standard test parameters, with overrides."""
        params = {"id": "test", "concurrency": 1, "model": "test-model",
                  "taskgroup": AsyncMock()}
        params.update(overrides)
        return Processor(**params)

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that default model is loaded during initialization"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_hf_class)

        processor = self._make_processor()

        mock_hf_class.assert_called_once_with(model_name="test-model")
        assert processor.default_model == "test-model"
        assert processor.cached_model_name == "test-model"
        assert processor.embeddings is not None

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that using the same model doesn't reload it"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_hf_class)
        processor = self._make_processor()
        mock_hf_class.reset_mock()

        # Repeatedly request the model that is already cached.
        for _ in range(3):
            processor._load_model("test-model")

        # The cached instance must be reused: no new HuggingFaceEmbeddings constructed.
        mock_hf_class.assert_not_called()
        assert processor.cached_model_name == "test-model"

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that changing model name triggers reload"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_hf_class)
        processor = self._make_processor()
        mock_hf_class.reset_mock()

        processor._load_model("different-model")

        mock_hf_class.assert_called_once_with(model_name="different-model")
        assert processor.cached_model_name == "different-model"

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that on_embeddings uses default model when no model specified"""
        fake_embedder = self._wire_mocks(mock_embeddings_init, mock_async_init, mock_hf_class)
        processor = self._make_processor()
        mock_hf_class.reset_mock()

        vectors = await processor.on_embeddings("test text")

        fake_embedder.embed_documents.assert_called_once_with(["test text"])
        assert processor.cached_model_name == "test-model"  # still the default
        assert vectors == [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that on_embeddings uses specified model when provided"""
        fake_embedder = self._wire_mocks(mock_embeddings_init, mock_async_init, mock_hf_class)
        processor = self._make_processor()
        mock_hf_class.reset_mock()

        await processor.on_embeddings("test text", model="custom-model")

        mock_hf_class.assert_called_once_with(model_name="custom-model")
        assert processor.cached_model_name == "custom-model"
        fake_embedder.embed_documents.assert_called_once_with(["test text"])

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test switching between multiple models"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_hf_class)
        processor = self._make_processor()
        baseline = mock_hf_class.call_count

        async def loads_after(text, model):
            # Run one embedding request, then report loads since the baseline.
            await processor.on_embeddings(text, model=model)
            return mock_hf_class.call_count - baseline

        assert await loads_after("text1", "model-a") == 1  # first load
        assert await loads_after("text2", "model-a") == 1  # same model, no reload
        assert await loads_after("text3", "model-b") == 2  # reload for model-b
        assert await loads_after("text4", "model-a") == 3  # reload back to model-a

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that None model parameter falls back to default"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_hf_class)
        processor = self._make_processor()
        baseline = mock_hf_class.call_count

        await processor.on_embeddings("test text", model=None)

        # None falls back to the cached default: no additional load occurs.
        assert mock_hf_class.call_count == baseline
        assert processor.cached_model_name == "test-model"

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test initialization without model parameter uses module default"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_hf_class)

        processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())

        # With no model argument, the module-level default should be loaded.
        expected_default = "all-MiniLM-L6-v2"
        mock_hf_class.assert_called_once_with(model_name=expected_default)
        assert processor.default_model == expected_default
        assert processor.cached_model_name == expected_default
|
||||
if __name__ == '__main__':
    # pytest.main returns an exit code; propagate it so running this file
    # directly reports failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
||||
167
tests/unit/test_embeddings/test_ollama_dynamic_model.py
Normal file
167
tests/unit/test_embeddings/test_ollama_dynamic_model.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
"""
|
||||
Unit tests for Ollama dynamic model loading
|
||||
|
||||
Tests the dynamic model selection functionality for Ollama embeddings service.
|
||||
Since Ollama is server-side, no model caching is needed on the client side.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from trustgraph.embeddings.ollama.processor import Processor
|
||||
|
||||
|
||||
class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
    """Test Ollama dynamic model selection"""

    @staticmethod
    def _wire_mocks(mock_embeddings_init, mock_async_init, mock_client_class):
        """Point the patched constructors at benign fakes; return the fake Ollama client."""
        fake_client = Mock()
        fake_reply = Mock()
        fake_reply.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        fake_client.embed.return_value = fake_reply
        mock_client_class.return_value = fake_client
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        return fake_client

    @staticmethod
    def _make_processor(**overrides):
        """Build a Processor from the standard test parameters, with overrides."""
        params = {"id": "test", "concurrency": 1, "model": "test-model",
                  "taskgroup": AsyncMock()}
        params.update(overrides)
        return Processor(**params)

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that Ollama client is initialized with correct host"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_client_class)

        processor = self._make_processor(ollama="http://localhost:11434")

        mock_client_class.assert_called_once_with(host="http://localhost:11434")
        assert processor.default_model == "test-model"

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that on_embeddings uses default model when no model specified"""
        fake_client = self._wire_mocks(mock_embeddings_init, mock_async_init, mock_client_class)
        processor = self._make_processor()

        vectors = await processor.on_embeddings("test text")

        fake_client.embed.assert_called_once_with(
            model="test-model",
            input="test text"
        )
        assert vectors == [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that on_embeddings uses specified model when provided"""
        fake_client = self._wire_mocks(mock_embeddings_init, mock_async_init, mock_client_class)
        processor = self._make_processor()

        vectors = await processor.on_embeddings("test text", model="custom-model")

        fake_client.embed.assert_called_once_with(
            model="custom-model",
            input="test text"
        )
        assert vectors == [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test switching between multiple models"""
        fake_client = self._wire_mocks(mock_embeddings_init, mock_async_init, mock_client_class)
        processor = self._make_processor()

        # Alternate named models and finish with the default path.
        await processor.on_embeddings("text1", model="model-a")
        await processor.on_embeddings("text2", model="model-b")
        await processor.on_embeddings("text3", model="model-a")
        await processor.on_embeddings("text4")  # default

        requested = [kwargs["model"] for _, kwargs in fake_client.embed.call_args_list]
        assert requested == ["model-a", "model-b", "model-a", "test-model"]

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that None model parameter falls back to default"""
        fake_client = self._wire_mocks(mock_embeddings_init, mock_async_init, mock_client_class)
        processor = self._make_processor()

        await processor.on_embeddings("test text", model=None)

        # None falls back to the default model name in the server request.
        fake_client.embed.assert_called_once_with(
            model="test-model",
            input="test text"
        )

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test initialization without model parameter uses module default"""
        self._wire_mocks(mock_embeddings_init, mock_async_init, mock_client_class)

        processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())

        # With no model argument, the module-level default should apply.
        assert processor.default_model == "mxbai-embed-large"
|
||||
if __name__ == '__main__':
    # pytest.main returns an exit code; propagate it so running this file
    # directly reports failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
||||
Loading…
Add table
Add a link
Reference in a new issue