Dynamic embeddings model (#556)

* Dynamic embeddings model selection

* Added tests

* HF embeddings are skipped, tests don't run with that package currently tests
This commit is contained in:
cybermaggedon 2025-11-10 20:38:01 +00:00 committed by GitHub
parent 6129bb68c1
commit d9d4c91363
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 816 additions and 11 deletions

View file

@ -0,0 +1,157 @@
"""
Contract tests for EmbeddingsService base class
Tests the contract between the EmbeddingsService base class and its
implementations, ensuring proper integration of the model parameter handling.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.base import EmbeddingsService
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
class ConcreteEmbeddingsService(EmbeddingsService):
"""Concrete implementation for testing the abstract base class"""
def __init__(self, **params):
self.on_embeddings_calls = []
self.default_model = params.get("model", "default-test-model")
# Don't call super().__init__ to avoid taskgroup requirements in tests
# We're only testing the on_embeddings interface
async def on_embeddings(self, text, model=None):
"""Implementation that tracks calls"""
self.on_embeddings_calls.append({
"text": text,
"model": model
})
# Return a simple embedding
return [[0.1, 0.2, 0.3]]
class TestEmbeddingsServiceModelParameterContract(IsolatedAsyncioTestCase):
"""Test the model parameter contract in embeddings implementations"""
async def test_on_embeddings_accepts_model_parameter(self):
"""Test that on_embeddings method accepts optional model parameter"""
# Arrange
service = ConcreteEmbeddingsService(model="default-model")
# Act
result1 = await service.on_embeddings("test text")
result2 = await service.on_embeddings("test text", model="custom-model")
result3 = await service.on_embeddings("test text", model=None)
# Assert
assert len(service.on_embeddings_calls) == 3
assert service.on_embeddings_calls[0]["model"] is None # No model specified
assert service.on_embeddings_calls[1]["model"] == "custom-model"
assert service.on_embeddings_calls[2]["model"] is None
async def test_implementation_tracks_model_changes(self):
"""Test that implementations properly track which model is requested"""
# Arrange
service = ConcreteEmbeddingsService(model="default-model")
# Act - multiple requests with different models
await service.on_embeddings("text1", model="model-a")
await service.on_embeddings("text2", model="model-b")
await service.on_embeddings("text3") # Use default (None passed)
await service.on_embeddings("text4", model="model-a")
# Assert
assert len(service.on_embeddings_calls) == 4
assert service.on_embeddings_calls[0]["model"] == "model-a"
assert service.on_embeddings_calls[1]["model"] == "model-b"
assert service.on_embeddings_calls[2]["model"] is None
assert service.on_embeddings_calls[3]["model"] == "model-a"
async def test_model_parameter_with_various_text_inputs(self):
"""Test model parameter works with different text inputs"""
# Arrange
service = ConcreteEmbeddingsService(model="default-model")
test_cases = [
("Simple text", "model-1"),
("", "model-2"),
("Unicode: 世界 🌍", "model-3"),
("Very " * 100 + "long text", None),
]
# Act
for text, model in test_cases:
await service.on_embeddings(text, model=model)
# Assert
assert len(service.on_embeddings_calls) == len(test_cases)
for i, (text, model) in enumerate(test_cases):
assert service.on_embeddings_calls[i]["text"] == text
assert service.on_embeddings_calls[i]["model"] == model
async def test_embeddings_return_format(self):
"""Test that embeddings are returned in correct format"""
# Arrange
service = ConcreteEmbeddingsService(model="default-model")
# Act
result = await service.on_embeddings("test text", model="test-model")
# Assert
assert isinstance(result, list)
assert len(result) > 0
assert isinstance(result[0], list)
assert all(isinstance(x, float) for x in result[0])
class TestEmbeddingsResponseSchema:
"""Test the EmbeddingsResponse schema contract"""
def test_success_response(self):
"""Test creating success response"""
# Act
response = EmbeddingsResponse(
error=None,
vectors=[[0.1, 0.2, 0.3]]
)
# Assert
assert response.error is None
assert response.vectors == [[0.1, 0.2, 0.3]]
def test_error_response(self):
"""Test creating error response"""
# Act
error = Error(type="test-error", message="Test message")
response = EmbeddingsResponse(
error=error,
vectors=None
)
# Assert
assert response.error is not None
assert response.error.type == "test-error"
assert response.error.message == "Test message"
assert response.vectors is None
def test_response_with_multiple_vectors(self):
"""Test response with multiple embedding vectors"""
# Act
vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
response = EmbeddingsResponse(
error=None,
vectors=vectors
)
# Assert
assert len(response.vectors) == 3
assert response.vectors[0] == [0.1, 0.2, 0.3]
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,216 @@
"""
Unit tests for FastEmbed dynamic model loading
Tests the model caching and dynamic loading functionality for FastEmbed
embeddings service.
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from unittest import IsolatedAsyncioTestCase
from trustgraph.embeddings.fastembed.processor import Processor
class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
"""Test FastEmbed dynamic model loading and caching"""
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test that default model is loaded during initialization"""
# Arrange
mock_fastembed_instance = Mock()
mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
base_params = {
"id": "test-embeddings",
"concurrency": 1,
"model": "test-model",
"taskgroup": AsyncMock()
}
# Act
processor = Processor(**base_params)
# Assert
mock_text_embedding_class.assert_called_once_with(model_name="test-model")
assert processor.default_model == "test-model"
assert processor.cached_model_name == "test-model"
assert processor.embeddings is not None
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test that using the same model doesn't reload it"""
# Arrange
mock_fastembed_instance = Mock()
mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_text_embedding_class.reset_mock()
# Act - use same model multiple times
processor._load_model("test-model")
processor._load_model("test-model")
processor._load_model("test-model")
# Assert - model should not be reloaded
mock_text_embedding_class.assert_not_called()
assert processor.cached_model_name == "test-model"
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test that changing model name triggers reload"""
# Arrange
mock_fastembed_instance = Mock()
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_text_embedding_class.reset_mock()
# Act - switch to different model
processor._load_model("different-model")
# Assert
mock_text_embedding_class.assert_called_once_with(model_name="different-model")
assert processor.cached_model_name == "different-model"
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test that on_embeddings uses default model when no model specified"""
# Arrange
mock_fastembed_instance = Mock()
mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_text_embedding_class.reset_mock()
# Act
result = await processor.on_embeddings("test text")
# Assert
mock_fastembed_instance.embed.assert_called_once_with(["test text"])
assert processor.cached_model_name == "test-model" # Still using default
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test that on_embeddings uses specified model when provided"""
# Arrange
mock_fastembed_instance = Mock()
mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_text_embedding_class.reset_mock()
# Act
result = await processor.on_embeddings("test text", model="custom-model")
# Assert
mock_text_embedding_class.assert_called_once_with(model_name="custom-model")
assert processor.cached_model_name == "custom-model"
mock_fastembed_instance.embed.assert_called_once_with(["test text"])
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test switching between multiple models"""
# Arrange
mock_fastembed_instance = Mock()
mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
initial_call_count = mock_text_embedding_class.call_count
# Act - switch between models
await processor.on_embeddings("text1", model="model-a")
call_count_after_a = mock_text_embedding_class.call_count
await processor.on_embeddings("text2", model="model-a") # Same, no reload
call_count_after_a_repeat = mock_text_embedding_class.call_count
await processor.on_embeddings("text3", model="model-b") # Different, reload
call_count_after_b = mock_text_embedding_class.call_count
await processor.on_embeddings("text4", model="model-a") # Back to A, reload
call_count_after_a_again = mock_text_embedding_class.call_count
# Assert
assert call_count_after_a == initial_call_count + 1 # First load
assert call_count_after_a_repeat == initial_call_count + 1 # No reload
assert call_count_after_b == initial_call_count + 2 # Reload for model-b
assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test that None model parameter falls back to default"""
# Arrange
mock_fastembed_instance = Mock()
mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
initial_count = mock_text_embedding_class.call_count
# Act
result = await processor.on_embeddings("test text", model=None)
# Assert
# No reload, using cached default
assert mock_text_embedding_class.call_count == initial_count
assert processor.cached_model_name == "test-model"
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test initialization without model parameter uses module default"""
# Arrange
mock_fastembed_instance = Mock()
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
# Act
processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())
# Assert
# Should use default_model from module
expected_default = "sentence-transformers/all-MiniLM-L6-v2"
mock_text_embedding_class.assert_called_once_with(model_name=expected_default)
assert processor.default_model == expected_default
assert processor.cached_model_name == expected_default
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,213 @@
"""
Unit tests for HuggingFace dynamic model loading
Tests the model caching and dynamic loading functionality for HuggingFace
embeddings service using LangChain's HuggingFaceEmbeddings.
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from unittest import IsolatedAsyncioTestCase
# Skip all tests in this module if trustgraph.embeddings.hf is not installed
pytest.importorskip("trustgraph.embeddings.hf")
from trustgraph.embeddings.hf.hf import Processor
class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
"""Test HuggingFace dynamic model loading and caching"""
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_hf_class):
"""Test that default model is loaded during initialization"""
# Arrange
mock_hf_instance = Mock()
mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_hf_class.return_value = mock_hf_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
# Act
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
# Assert
mock_hf_class.assert_called_once_with(model_name="test-model")
assert processor.default_model == "test-model"
assert processor.cached_model_name == "test-model"
assert processor.embeddings is not None
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_hf_class):
"""Test that using the same model doesn't reload it"""
# Arrange
mock_hf_instance = Mock()
mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_hf_class.return_value = mock_hf_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_hf_class.reset_mock()
# Act - use same model multiple times
processor._load_model("test-model")
processor._load_model("test-model")
processor._load_model("test-model")
# Assert - model should not be reloaded
mock_hf_class.assert_not_called()
assert processor.cached_model_name == "test-model"
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_hf_class):
"""Test that changing model name triggers reload"""
# Arrange
mock_hf_instance = Mock()
mock_hf_class.return_value = mock_hf_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_hf_class.reset_mock()
# Act - switch to different model
processor._load_model("different-model")
# Assert
mock_hf_class.assert_called_once_with(model_name="different-model")
assert processor.cached_model_name == "different-model"
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
"""Test that on_embeddings uses default model when no model specified"""
# Arrange
mock_hf_instance = Mock()
mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_hf_class.return_value = mock_hf_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_hf_class.reset_mock()
# Act
result = await processor.on_embeddings("test text")
# Assert
mock_hf_instance.embed_documents.assert_called_once_with(["test text"])
assert processor.cached_model_name == "test-model" # Still using default
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
"""Test that on_embeddings uses specified model when provided"""
# Arrange
mock_hf_instance = Mock()
mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_hf_class.return_value = mock_hf_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_hf_class.reset_mock()
# Act
result = await processor.on_embeddings("test text", model="custom-model")
# Assert
mock_hf_class.assert_called_once_with(model_name="custom-model")
assert processor.cached_model_name == "custom-model"
mock_hf_instance.embed_documents.assert_called_once_with(["test text"])
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_hf_class):
"""Test switching between multiple models"""
# Arrange
mock_hf_instance = Mock()
mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_hf_class.return_value = mock_hf_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
initial_call_count = mock_hf_class.call_count
# Act - switch between models
await processor.on_embeddings("text1", model="model-a")
call_count_after_a = mock_hf_class.call_count
await processor.on_embeddings("text2", model="model-a") # Same, no reload
call_count_after_a_repeat = mock_hf_class.call_count
await processor.on_embeddings("text3", model="model-b") # Different, reload
call_count_after_b = mock_hf_class.call_count
await processor.on_embeddings("text4", model="model-a") # Back to A, reload
call_count_after_a_again = mock_hf_class.call_count
# Assert
assert call_count_after_a == initial_call_count + 1 # First load
assert call_count_after_a_repeat == initial_call_count + 1 # No reload
assert call_count_after_b == initial_call_count + 2 # Reload for model-b
assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
"""Test that None model parameter falls back to default"""
# Arrange
mock_hf_instance = Mock()
mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_hf_class.return_value = mock_hf_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
initial_count = mock_hf_class.call_count
# Act
result = await processor.on_embeddings("test text", model=None)
# Assert
# No reload, using cached default
assert mock_hf_class.call_count == initial_count
assert processor.cached_model_name == "test-model"
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
"""Test initialization without model parameter uses module default"""
# Arrange
mock_hf_instance = Mock()
mock_hf_class.return_value = mock_hf_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
# Act
processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())
# Assert
# Should use default_model from module
expected_default = "all-MiniLM-L6-v2"
mock_hf_class.assert_called_once_with(model_name=expected_default)
assert processor.default_model == expected_default
assert processor.cached_model_name == expected_default
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,167 @@
"""
Unit tests for Ollama dynamic model loading
Tests the dynamic model selection functionality for Ollama embeddings service.
Since Ollama is server-side, no model caching is needed on the client side.
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from unittest import IsolatedAsyncioTestCase
from trustgraph.embeddings.ollama.processor import Processor
class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
"""Test Ollama dynamic model selection"""
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that Ollama client is initialized with correct host"""
# Arrange
mock_ollama_client = Mock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
mock_client_class.return_value = mock_ollama_client
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
# Act
processor = Processor(id="test", concurrency=1, model="test-model",
ollama="http://localhost:11434", taskgroup=AsyncMock())
# Assert
mock_client_class.assert_called_once_with(host="http://localhost:11434")
assert processor.default_model == "test-model"
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that on_embeddings uses default model when no model specified"""
# Arrange
mock_ollama_client = Mock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
mock_client_class.return_value = mock_ollama_client
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
# Act
result = await processor.on_embeddings("test text")
# Assert
mock_ollama_client.embed.assert_called_once_with(
model="test-model",
input="test text"
)
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that on_embeddings uses specified model when provided"""
# Arrange
mock_ollama_client = Mock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
mock_client_class.return_value = mock_ollama_client
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
# Act
result = await processor.on_embeddings("test text", model="custom-model")
# Assert
mock_ollama_client.embed.assert_called_once_with(
model="custom-model",
input="test text"
)
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test switching between multiple models"""
# Arrange
mock_ollama_client = Mock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
mock_client_class.return_value = mock_ollama_client
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
# Act - switch between different models
await processor.on_embeddings("text1", model="model-a")
await processor.on_embeddings("text2", model="model-b")
await processor.on_embeddings("text3", model="model-a")
await processor.on_embeddings("text4") # Use default
# Assert
calls = mock_ollama_client.embed.call_args_list
assert len(calls) == 4
assert calls[0][1]['model'] == "model-a"
assert calls[1][1]['model'] == "model-b"
assert calls[2][1]['model'] == "model-a"
assert calls[3][1]['model'] == "test-model" # Default
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that None model parameter falls back to default"""
# Arrange
mock_ollama_client = Mock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
mock_client_class.return_value = mock_ollama_client
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
# Act
result = await processor.on_embeddings("test text", model=None)
# Assert
mock_ollama_client.embed.assert_called_once_with(
model="test-model",
input="test text"
)
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test initialization without model parameter uses module default"""
# Arrange
mock_ollama_client = Mock()
mock_client_class.return_value = mock_ollama_client
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
# Act
processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())
# Assert
# Should use default_model from module
expected_default = "mxbai-embed-large"
assert processor.default_model == expected_default
if __name__ == '__main__':
pytest.main([__file__])