diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 1a753b855..ba2b5527c 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -79,7 +79,7 @@ exp_pool: enable_read: false enable_write: false persist_path: .chroma_exp_data # The directory. - storage_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding. + retrieval_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding. azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py index 11669a301..ad918b481 100644 --- a/metagpt/configs/exp_pool_config.py +++ b/metagpt/configs/exp_pool_config.py @@ -5,7 +5,7 @@ from pydantic import Field from metagpt.utils.yaml_model import YamlModel -class ExperiencePoolStorageType(Enum): +class ExperiencePoolRetrievalType(Enum): BM25 = "bm25" CHROMA = "chroma" @@ -18,6 +18,6 @@ class ExperiencePoolConfig(YamlModel): enable_read: bool = Field(default=False, description="Enable to read from experience pool.") enable_write: bool = Field(default=False, description="Enable to write to experience pool.") persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.") - storage_type: ExperiencePoolStorageType = Field( - default=ExperiencePoolStorageType.BM25, description="The storage type for experience pool." + retrieval_type: ExperiencePoolRetrievalType = Field( + default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool." ) diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index ed5f5e068..b9ce61e29 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -132,14 +132,14 @@ class ExpCacheHandler(BaseModel): """Fetch experiences by query_type.""" self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag) - logger.debug(f"Found {len(self._exps)} experiences for req '{self._req[:20]}...' and tag '{self.tag}'") + logger.info(f"Found {len(self._exps)} experiences for tag '{self.tag}'") async def get_one_perfect_exp(self) -> Optional[Any]: """Get a potentially perfect experience, and resolve resp.""" for exp in self._exps: if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs): - logger.debug(f"Got one perfect experience for req '{exp.req[:20]}...'") + logger.info(f"Got one perfect experience for req '{exp.req[:20]}...'") return self.serializer.deserialize_resp(exp.resp) return None diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 8388eabb7..40528811a 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field from metagpt.config2 import Config -from metagpt.configs.exp_pool_config import ExperiencePoolStorageType +from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType from metagpt.exp_pool.schema import ( DEFAULT_COLLECTION_NAME, DEFAULT_SIMILARITY_TOP_K, @@ -95,11 +95,11 @@ class ExperienceManager(BaseModel): """Selects the appropriate storage creation method based on the configured storage type.""" storage_creators = { - ExperiencePoolStorageType.BM25: self._create_bm25_storage, - ExperiencePoolStorageType.CHROMA: self._create_chroma_storage, + ExperiencePoolRetrievalType.BM25: self._create_bm25_storage, + ExperiencePoolRetrievalType.CHROMA: self._create_chroma_storage, } - return storage_creators[self.config.exp_pool.storage_type]() + return storage_creators[self.config.exp_pool.retrieval_type]() def _create_bm25_storage(self) -> "SimpleEngine": """Creates or loads BM25 storage. @@ -116,8 +116,6 @@ class ExperienceManager(BaseModel): """ try: - from llama_index.core import VectorStoreIndex - from metagpt.rag.engines import SimpleEngine from metagpt.rag.schema import ( BM25IndexConfig, @@ -133,11 +131,8 @@ class ExperienceManager(BaseModel): if not docstore_path.exists(): logger.debug(f"Path `{docstore_path}` not exists, try to create a new bm25 storage.") exps = [Experience(req="req", resp="resp")] - nodes = SimpleEngine.get_obj_nodes(exps) - embed_model = SimpleEngine._resolve_embed_model(configs=[BM25RetrieverConfig()]) - index = VectorStoreIndex(nodes, embed_model=embed_model) - retriever_configs = [BM25RetrieverConfig(index=index)] + retriever_configs = [BM25RetrieverConfig(create_index=True)] ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] storage = SimpleEngine.from_objs( diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index b3b110330..6bc8e4ad5 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -7,6 +7,7 @@ import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.embeddings.mock_embed_model import MockEmbedding from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.chroma import ChromaVectorStore @@ -84,9 +85,13 @@ class RetrieverFactory(ConfigBasedFactory): def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: index = self._extract_index(config, **kwargs) nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs) + if index and not config.index: config.index = index + if not config.index and config.create_index: + config.index = VectorStoreIndex(nodes, embed_model=MockEmbedding(embed_dim=1)) + return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 5e97e60c3..5be2b050b 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -60,6 +60,11 @@ class FAISSRetrieverConfig(IndexRetrieverConfig): class BM25RetrieverConfig(IndexRetrieverConfig): """Config for BM25-based retrievers.""" + create_index: bool = Field( + default=False, + description="Indicates whether to create an index for the nodes. It is useful when you need to persist data while only using BM25.", + exclude=True, + ) _no_embedding: bool = PrivateAttr(default=True) diff --git a/tests/metagpt/exp_pool/test_manager.py b/tests/metagpt/exp_pool/test_manager.py index c99d28ba6..933232031 100644 --- a/tests/metagpt/exp_pool/test_manager.py +++ b/tests/metagpt/exp_pool/test_manager.py @@ -3,7 +3,7 @@ import pytest from metagpt.config2 import Config from metagpt.configs.exp_pool_config import ( ExperiencePoolConfig, - ExperiencePoolStorageType, + ExperiencePoolRetrievalType, ) from metagpt.configs.llm_config import LLMConfig from metagpt.exp_pool.manager import Experience, ExperienceManager @@ -16,7 +16,7 @@ class TestExperienceManager: return Config( llm=LLMConfig(), exp_pool=ExperiencePoolConfig( - enable_write=True, enable_read=True, enabled=True, storage_type=ExperiencePoolStorageType.BM25 + enable_write=True, enable_read=True, enabled=True, retrieval_type=ExperiencePoolRetrievalType.BM25 ), ) @@ -96,7 +96,7 @@ class TestExperienceManager: assert exp_manager.get_exps_count() == 10 def test_resolve_storage_bm25(self, mocker, mock_config): - mock_config.exp_pool.storage_type = ExperiencePoolStorageType.BM25 + mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.BM25 mocker.patch.object(ExperienceManager, "_create_bm25_storage", return_value=mocker.MagicMock()) manager = ExperienceManager(config=mock_config) storage = manager._resolve_storage() @@ -104,7 +104,7 @@ class TestExperienceManager: assert storage is not None def test_resolve_storage_chroma(self, mocker, mock_config): - mock_config.exp_pool.storage_type = ExperiencePoolStorageType.CHROMA + mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.CHROMA mocker.patch.object(ExperienceManager, "_create_chroma_storage", return_value=mocker.MagicMock()) manager = ExperienceManager(config=mock_config) storage = manager._resolve_storage() diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py index 8c7a15be2..e0a174ed2 100644 --- a/tests/metagpt/rag/engines/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -75,7 +75,7 @@ class TestSimpleEngine: ) # Assert - mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) + mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files, fs=None) mock_get_retriever.assert_called_once() mock_get_rankers.assert_called_once() mock_get_response_synthesizer.assert_called_once_with(llm=llm) diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py index 1a9e9b2c9..03bdfab1d 100644 --- a/tests/metagpt/rag/factories/test_embedding.py +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -1,5 +1,6 @@ import pytest +from metagpt.config2 import Config from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.embedding import RAGEmbeddingFactory @@ -12,7 +13,10 @@ class TestRAGEmbeddingFactory: @pytest.fixture def mock_config(self, mocker): - return mocker.patch("metagpt.rag.factories.embedding.config") + config = Config.default().model_copy(deep=True) + default = mocker.patch("metagpt.config2.Config.default") + default.return_value = config + return config @staticmethod def mock_openai_embedding(mocker):