Merge branch 'feat_werewolf' of github.com:better629/MetaGPT into feat_werewolf

This commit is contained in:
better629 2024-04-10 20:59:08 +08:00
commit bdfec451de
17 changed files with 436 additions and 40 deletions

View file

@ -1,5 +1,6 @@
import pytest
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory:
self.embedding_factory = RAGEmbeddingFactory()
@pytest.fixture
def mock_openai_embedding(self, mocker):
def mock_config(self, mocker):
return mocker.patch("metagpt.rag.factories.embedding.config")
@staticmethod
def mock_openai_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
@pytest.fixture
def mock_azure_embedding(self, mocker):
@staticmethod
def mock_azure_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
def test_get_rag_embedding_openai(self, mock_openai_embedding):
# Exec
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
@staticmethod
def mock_gemini_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding")
# Assert
mock_openai_embedding.assert_called_once()
@staticmethod
def mock_ollama_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding")
def test_get_rag_embedding_azure(self, mock_azure_embedding):
# Exec
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
# Assert
mock_azure_embedding.assert_called_once()
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
@pytest.mark.parametrize(
("mock_func", "embedding_type"),
[
(mock_openai_embedding, LLMType.OPENAI),
(mock_azure_embedding, LLMType.AZURE),
(mock_openai_embedding, EmbeddingType.OPENAI),
(mock_azure_embedding, EmbeddingType.AZURE),
(mock_gemini_embedding, EmbeddingType.GEMINI),
(mock_ollama_embedding, EmbeddingType.OLLAMA),
],
)
def test_get_rag_embedding(self, mock_func, embedding_type, mocker):
# Mock
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
mock = mock_func(mocker)
# Exec
self.embedding_factory.get_rag_embedding(embedding_type)
# Assert
mock.assert_called_once()
def test_get_rag_embedding_default(self, mocker, mock_config):
# Mock
mock_openai_embedding = self.mock_openai_embedding(mocker)
mock_config.embedding.api_type = None
mock_config.llm.api_type = LLMType.OPENAI
# Exec
@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory:
# Assert
mock_openai_embedding.assert_called_once()
@pytest.mark.parametrize(
"model, embed_batch_size, expected_params",
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
)
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
# Mock
mock_config.embedding.model = model
mock_config.embedding.embed_batch_size = embed_batch_size
# Setup
test_params = {}
# Exec
self.embedding_factory._try_set_model_and_batch_size(test_params)
# Assert
assert test_params == expected_params
def test_resolve_embedding_type(self, mock_config):
# Mock
mock_config.embedding.api_type = EmbeddingType.OPENAI
# Exec
embedding_type = self.embedding_factory._resolve_embedding_type()
# Assert
assert embedding_type == EmbeddingType.OPENAI
def test_resolve_embedding_type_exception(self, mock_config):
# Mock
mock_config.embedding.api_type = None
mock_config.llm.api_type = LLMType.GEMINI
# Assert
with pytest.raises(TypeError):
self.embedding_factory._resolve_embedding_type()
def test_raise_for_key(self):
with pytest.raises(ValueError):
self.embedding_factory._raise_for_key("key")