Merge branch 'feat-exp-pool-bm25' into 'mgx_ops'

Feat exp pool bm25

See merge request pub/MetaGPT!327
This commit is contained in:
王金淋 2024-08-19 06:45:03 +00:00
commit c6cc5e2da3
14 changed files with 244 additions and 54 deletions

View file

@ -1,7 +1,10 @@
import pytest
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
from metagpt.configs.exp_pool_config import (
ExperiencePoolConfig,
ExperiencePoolRetrievalType,
)
from metagpt.configs.llm_config import LLMConfig
from metagpt.exp_pool.manager import Experience, ExperienceManager
from metagpt.exp_pool.schema import QueryType
@ -10,17 +13,19 @@ from metagpt.exp_pool.schema import QueryType
class TestExperienceManager:
@pytest.fixture
def mock_config(self):
# Fixture: builds a Config with a default LLMConfig and an experience pool whose
# enable_write, enable_read and enabled flags are all True.
# NOTE(review): this view looks like a rendered diff with the +/- markers and the
# indentation stripped, so the one-line return below is presumably the pre-change
# version and the multi-line return (which additionally selects BM25 retrieval) the
# post-change one — confirm against the real file.
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, enabled=True))
return Config(
llm=LLMConfig(),
exp_pool=ExperiencePoolConfig(
enable_write=True, enable_read=True, enabled=True, retrieval_type=ExperiencePoolRetrievalType.BM25
),
)
@pytest.fixture
def mock_storage(self, mocker):
# Fixture: a MagicMock standing in for the storage engine used by the manager.
engine = mocker.MagicMock()
# add_objs is a plain mock so tests can assert on how it was called.
engine.add_objs = mocker.MagicMock()
# aretrieve is awaited by the code under test, hence AsyncMock; it yields no results.
engine.aretrieve = mocker.AsyncMock(return_value=[])
# NOTE(review): the four `_retriever...` lines below reach into the engine's private
# vector-store collection to fake a count of 10, while `engine.count` further down
# fakes the same value directly. This presumably reflects old (removed) and new
# (added) diff lines shown together — confirm which set the real file keeps.
engine._retriever = mocker.MagicMock()
engine._retriever._vector_store = mocker.MagicMock()
engine._retriever._vector_store._collection = mocker.MagicMock()
engine._retriever._vector_store._collection.count = mocker.MagicMock(return_value=10)
# count() reports 10 stored items.
engine.count = mocker.MagicMock(return_value=10)
return engine
@pytest.fixture
@ -29,8 +34,33 @@ class TestExperienceManager:
manager._storage = mock_storage
return manager
def test_vector_store_property(self, exp_manager):
# Checks that the manager's `vector_store` attribute is the storage retriever's
# private `_vector_store` object.
# NOTE(review): presumably the removed side of this diff hunk — verify whether the
# real file still defines this test.
assert exp_manager.vector_store == exp_manager.storage._retriever._vector_store
def test_storage_property(self, exp_manager, mock_storage):
    """The ``storage`` property hands back the mocked engine given to the manager."""
    exposed = exp_manager.storage
    assert exposed == mock_storage
def test_storage_property_initialization(self, mocker, mock_config):
    """``_storage`` starts as None and is filled in once ``storage`` is read."""
    resolved_stub = mocker.MagicMock()
    mocker.patch.object(ExperienceManager, "_resolve_storage", return_value=resolved_stub)

    lazy_manager = ExperienceManager(config=mock_config)
    # Nothing has been resolved before the property is touched.
    assert lazy_manager._storage is None

    # Reading the property is the trigger; the value itself is not needed here.
    _ = lazy_manager.storage
    assert lazy_manager._storage is not None
def test_create_exp_write_disabled(self, exp_manager, mock_config):
    """With ``enable_write`` switched off, ``create_exp`` adds nothing to storage."""
    mock_config.exp_pool.enable_write = False

    exp_manager.create_exp(Experience(req="test", resp="response"))

    add_objs = exp_manager.storage.add_objs
    add_objs.assert_not_called()
def test_create_exp_write_enabled(self, exp_manager):
    """``create_exp`` adds the experience to storage and persists to the configured path."""
    new_exp = Experience(req="test", resp="response")

    exp_manager.create_exp(new_exp)

    storage = exp_manager.storage
    storage.add_objs.assert_called_once_with([new_exp])
    storage.persist.assert_called_once_with(exp_manager.config.exp_pool.persist_path)
@pytest.mark.asyncio
async def test_query_exps_read_disabled(self, exp_manager, mock_config):
    """With ``enable_read`` switched off, ``query_exps`` comes back empty."""
    mock_config.exp_pool.enable_read = False

    assert await exp_manager.query_exps("query") == []
@pytest.mark.asyncio
async def test_query_exps_with_exact_match(self, exp_manager, mocker):
@ -65,14 +95,37 @@ class TestExperienceManager:
def test_get_exps_count(self, exp_manager):
    """``get_exps_count`` reports the count faked by the mocked storage."""
    expected_count = 10
    assert exp_manager.get_exps_count() == expected_count
def test_create_exp_write_disabled(self, exp_manager, mock_config):
# With writes disabled in the config, create_exp must not call storage.add_objs.
# NOTE(review): a test with this exact name and body also appears earlier in this
# view; this copy is presumably the removed (old-position) side of the diff hunk.
# If both really exist in the class, this later definition would shadow the earlier
# one — confirm against the real file.
mock_config.exp_pool.enable_write = False
exp = Experience(req="test", resp="response")
exp_manager.create_exp(exp)
exp_manager.storage.add_objs.assert_not_called()
def test_resolve_storage_bm25(self, mocker, mock_config):
    """A BM25 retrieval type makes ``_resolve_storage`` call ``_create_bm25_storage`` once."""
    mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.BM25
    bm25_factory = mocker.patch.object(
        ExperienceManager, "_create_bm25_storage", return_value=mocker.MagicMock()
    )

    resolved = ExperienceManager(config=mock_config)._resolve_storage()

    bm25_factory.assert_called_once()
    assert resolved is not None
@pytest.mark.asyncio
async def test_query_exps_read_disabled(self, exp_manager, mock_config):
# With reads disabled in the config, query_exps must return an empty list.
# NOTE(review): a test with this exact name and body also appears earlier in this
# view; this copy is presumably the removed (old-position) side of the diff hunk —
# confirm against the real file.
mock_config.exp_pool.enable_read = False
result = await exp_manager.query_exps("query")
assert result == []
def test_resolve_storage_chroma(self, mocker, mock_config):
    """A CHROMA retrieval type makes ``_resolve_storage`` call ``_create_chroma_storage`` once."""
    mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.CHROMA
    chroma_factory = mocker.patch.object(
        ExperienceManager, "_create_chroma_storage", return_value=mocker.MagicMock()
    )

    resolved = ExperienceManager(config=mock_config)._resolve_storage()

    chroma_factory.assert_called_once()
    assert resolved is not None
def test_create_bm25_storage(self, mocker, mock_config):
    """``_create_bm25_storage`` returns a storage when the RAG machinery is stubbed out.

    ``Path.exists`` is forced to False for the duration of the test.
    """
    # Patch target -> stubbed return value, applied in this order.
    stubbed_returns = {
        "metagpt.rag.engines.SimpleEngine.from_objs": mocker.MagicMock(),
        "metagpt.rag.engines.SimpleEngine.from_index": mocker.MagicMock(),
        "metagpt.rag.engines.SimpleEngine.get_obj_nodes": [],
        "metagpt.rag.engines.SimpleEngine._resolve_embed_model": mocker.MagicMock(),
        "llama_index.core.VectorStoreIndex": mocker.MagicMock(),
        "metagpt.rag.schema.BM25RetrieverConfig": mocker.MagicMock(),
        "pathlib.Path.exists": False,
    }
    for target, stub in stubbed_returns.items():
        mocker.patch(target, return_value=stub)

    bm25_storage = ExperienceManager(config=mock_config)._create_bm25_storage()

    assert bm25_storage is not None
def test_create_chroma_storage(self, mocker, mock_config):
    """``_create_chroma_storage`` returns a storage when ``SimpleEngine.from_objs`` is stubbed."""
    mocker.patch("metagpt.rag.engines.SimpleEngine.from_objs", return_value=mocker.MagicMock())

    chroma_storage = ExperienceManager(config=mock_config)._create_chroma_storage()

    assert chroma_storage is not None

View file

@ -75,7 +75,7 @@ class TestSimpleEngine:
)
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files, fs=None)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)

View file

@ -1,5 +1,6 @@
import pytest
from metagpt.config2 import Config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
@ -12,7 +13,10 @@ class TestRAGEmbeddingFactory:
@pytest.fixture
def mock_config(self, mocker):
# Fixture supplying the configuration object used by the embedding-factory tests.
# NOTE(review): this view looks like a rendered diff with the +/- markers and the
# indentation stripped, so the single return directly below (patching the factory
# module's `config` attribute) is presumably the pre-change body and the four lines
# after it the post-change body — confirm against the real file.
return mocker.patch("metagpt.rag.factories.embedding.config")
# Take a deep copy of the default Config, then patch Config.default so it returns
# that copy; the same copy is returned to the test.
config = Config.default().model_copy(deep=True)
default = mocker.patch("metagpt.config2.Config.default")
default.return_value = config
return config
@staticmethod
def mock_openai_embedding(mocker):