diff --git a/config/config2.example.yaml b/config/config2.example.yaml index ba2b5527c..a24892c2a 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -80,6 +80,7 @@ exp_pool: enable_write: false persist_path: .chroma_exp_data # The directory. retrieval_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding. + use_llm_ranker: false # If `use_llm_ranker` is true, then it will use LLM Reranker to get better result, but it is not always guaranteed that the output will be parseable for reranking. azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py index ad918b481..7611dda27 100644 --- a/metagpt/configs/exp_pool_config.py +++ b/metagpt/configs/exp_pool_config.py @@ -21,3 +21,4 @@ class ExperiencePoolConfig(YamlModel): retrieval_type: ExperiencePoolRetrievalType = Field( default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool." ) + use_llm_ranker: bool = Field(default=False, description="Use LLM Reranker to get better result.") diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 56cc43970..afa2459d9 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -117,18 +117,14 @@ class ExperienceManager(BaseModel): try: from metagpt.rag.engines import SimpleEngine - from metagpt.rag.schema import ( - BM25IndexConfig, - BM25RetrieverConfig, - LLMRankerConfig, - ) + from metagpt.rag.schema import BM25IndexConfig, BM25RetrieverConfig except ImportError: raise ImportError("To use the experience pool, you need to install the rag module.") persist_path = Path(self.config.exp_pool.persist_path) docstore_path = persist_path / "docstore.json" - ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] + ranker_configs = self._get_ranker_configs() if not docstore_path.exists(): logger.debug(f"Path `{docstore_path}` not exists, try to create a new bm25 storage.") @@ -163,7 +159,7 @@ class ExperienceManager(BaseModel): try: from metagpt.rag.engines import SimpleEngine - from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig + from metagpt.rag.schema import ChromaRetrieverConfig except ImportError: raise ImportError("To use the experience pool, you need to install the rag module.") @@ -174,12 +170,25 @@ class ExperienceManager(BaseModel): similarity_top_k=DEFAULT_SIMILARITY_TOP_K, ) ] - ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] + ranker_configs = self._get_ranker_configs() storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs) return storage + def _get_ranker_configs(self): + """Returns ranker configurations based on the configuration. + + If `use_llm_ranker` is True, returns a list with one `LLMRankerConfig` + instance. Otherwise, returns an empty list. + + Returns: + list: A list of `LLMRankerConfig` instances or an empty list. + """ + from metagpt.rag.schema import LLMRankerConfig + + return [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] if self.config.exp_pool.use_llm_ranker else [] + _exp_manager = None diff --git a/tests/metagpt/exp_pool/test_manager.py b/tests/metagpt/exp_pool/test_manager.py index 933232031..b0e4e8537 100644 --- a/tests/metagpt/exp_pool/test_manager.py +++ b/tests/metagpt/exp_pool/test_manager.py @@ -7,7 +7,7 @@ from metagpt.configs.exp_pool_config import ( ) from metagpt.configs.llm_config import LLMConfig from metagpt.exp_pool.manager import Experience, ExperienceManager -from metagpt.exp_pool.schema import QueryType +from metagpt.exp_pool.schema import DEFAULT_SIMILARITY_TOP_K, QueryType class TestExperienceManager: @@ -129,3 +129,16 @@ class TestExperienceManager: manager = ExperienceManager(config=mock_config) storage = manager._create_chroma_storage() assert storage is not None + + def test_get_ranker_configs_use_llm_ranker_true(self, mock_config): + mock_config.exp_pool.use_llm_ranker = True + manager = ExperienceManager(config=mock_config) + ranker_configs = manager._get_ranker_configs() + assert len(ranker_configs) == 1 + assert ranker_configs[0].top_n == DEFAULT_SIMILARITY_TOP_K + + def test_get_ranker_configs_use_llm_ranker_false(self, mock_config): + mock_config.exp_pool.use_llm_ranker = False + manager = ExperienceManager(config=mock_config) + ranker_configs = manager._get_ranker_configs() + assert len(ranker_configs) == 0