mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
Merge branch 'feat_werewolf' of github.com:better629/MetaGPT into feat_werewolf
This commit is contained in:
commit
bdfec451de
17 changed files with 436 additions and 40 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
|
||||
|
|
@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory:
|
|||
self.embedding_factory = RAGEmbeddingFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_embedding(self, mocker):
|
||||
def mock_config(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
|
||||
@staticmethod
|
||||
def mock_openai_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_embedding(self, mocker):
|
||||
@staticmethod
|
||||
def mock_azure_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
|
||||
|
||||
def test_get_rag_embedding_openai(self, mock_openai_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
|
||||
@staticmethod
|
||||
def mock_gemini_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding")
|
||||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
@staticmethod
|
||||
def mock_ollama_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding")
|
||||
|
||||
def test_get_rag_embedding_azure(self, mock_azure_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
|
||||
|
||||
# Assert
|
||||
mock_azure_embedding.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
|
||||
@pytest.mark.parametrize(
|
||||
("mock_func", "embedding_type"),
|
||||
[
|
||||
(mock_openai_embedding, LLMType.OPENAI),
|
||||
(mock_azure_embedding, LLMType.AZURE),
|
||||
(mock_openai_embedding, EmbeddingType.OPENAI),
|
||||
(mock_azure_embedding, EmbeddingType.AZURE),
|
||||
(mock_gemini_embedding, EmbeddingType.GEMINI),
|
||||
(mock_ollama_embedding, EmbeddingType.OLLAMA),
|
||||
],
|
||||
)
|
||||
def test_get_rag_embedding(self, mock_func, embedding_type, mocker):
|
||||
# Mock
|
||||
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
mock = mock_func(mocker)
|
||||
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(embedding_type)
|
||||
|
||||
# Assert
|
||||
mock.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_config):
|
||||
# Mock
|
||||
mock_openai_embedding = self.mock_openai_embedding(mocker)
|
||||
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.OPENAI
|
||||
|
||||
# Exec
|
||||
|
|
@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory:
|
|||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, embed_batch_size, expected_params",
|
||||
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
|
||||
)
|
||||
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
|
||||
# Mock
|
||||
mock_config.embedding.model = model
|
||||
mock_config.embedding.embed_batch_size = embed_batch_size
|
||||
|
||||
# Setup
|
||||
test_params = {}
|
||||
|
||||
# Exec
|
||||
self.embedding_factory._try_set_model_and_batch_size(test_params)
|
||||
|
||||
# Assert
|
||||
assert test_params == expected_params
|
||||
|
||||
def test_resolve_embedding_type(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = EmbeddingType.OPENAI
|
||||
|
||||
# Exec
|
||||
embedding_type = self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
# Assert
|
||||
assert embedding_type == EmbeddingType.OPENAI
|
||||
|
||||
def test_resolve_embedding_type_exception(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.GEMINI
|
||||
|
||||
# Assert
|
||||
with pytest.raises(TypeError):
|
||||
self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
def test_raise_for_key(self):
|
||||
with pytest.raises(ValueError):
|
||||
self.embedding_factory._raise_for_key("key")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue