diff --git a/config/config2.example.yaml b/config/config2.example.yaml index ba480d984..59406833a 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -79,6 +79,7 @@ exp_pool: enable_read: false enable_write: false persist_path: .chroma_exp_data # The directory. + storage_type: bm25 # Default is bm25, can be set to chroma for vector storage, which requires setting up embedding. azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/examples/exp_pool/README.md b/examples/exp_pool/README.md index d0b49f2ad..37e7853f8 100644 --- a/examples/exp_pool/README.md +++ b/examples/exp_pool/README.md @@ -3,7 +3,7 @@ # Experience Pool ## Prerequisites - Ensure the RAG module is installed: https://docs.deepwisdom.ai/main/en/guide/in_depth_guides/rag_module.html - Set embedding: https://docs.deepwisdom.ai/main/en/guide/in_depth_guides/rag_module.html -- Set both `enable_read` and `enable_write` to `true` in the `exp_pool` section of `config2.yaml` +- Set `enabled`, `enable_read` and `enable_write` to `true` in the `exp_pool` section of `config2.yaml` ## Example Files diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py index e2872179f..11669a301 100644 --- a/metagpt/configs/exp_pool_config.py +++ b/metagpt/configs/exp_pool_config.py @@ -1,8 +1,15 @@ +from enum import Enum + from pydantic import Field from metagpt.utils.yaml_model import YamlModel +class ExperiencePoolStorageType(Enum): + BM25 = "bm25" + CHROMA = "chroma" + + class ExperiencePoolConfig(YamlModel): enabled: bool = Field( default=False, @@ -11,3 +18,6 @@ class ExperiencePoolConfig(YamlModel): enable_read: bool = Field(default=False, description="Enable to read from experience pool.") enable_write: bool = Field(default=False, description="Enable to write to experience pool.") persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.") + storage_type: ExperiencePoolStorageType = 
Field( + default=ExperiencePoolStorageType.BM25, description="The storage type for experience pool." + ) diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 5f4d71edc..8388eabb7 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -1,10 +1,12 @@ """Experience Manager.""" +from pathlib import Path from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field from metagpt.config2 import Config +from metagpt.configs.exp_pool_config import ExperiencePoolStorageType from metagpt.exp_pool.schema import ( DEFAULT_COLLECTION_NAME, DEFAULT_SIMILARITY_TOP_K, @@ -15,7 +17,7 @@ from metagpt.logs import logger from metagpt.utils.exceptions import handle_exception if TYPE_CHECKING: - from llama_index.vector_stores.chroma import ChromaVectorStore + from metagpt.rag.engines import SimpleEngine class ExperienceManager(BaseModel): @@ -32,40 +34,16 @@ class ExperienceManager(BaseModel): config: Config = Field(default_factory=Config.default) _storage: Any = None - _vector_store: Any = None @property def storage(self): if self._storage is None: - try: - from metagpt.rag.engines import SimpleEngine - from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig - except ImportError: - raise ImportError("To use the experience pool, you need to install the rag module.") - - retriever_configs = [ - ChromaRetrieverConfig( - persist_path=self.config.exp_pool.persist_path, - collection_name=DEFAULT_COLLECTION_NAME, - similarity_top_k=DEFAULT_SIMILARITY_TOP_K, - ) - ] - ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] - - self._storage: SimpleEngine = SimpleEngine.from_objs( - retriever_configs=retriever_configs, ranker_configs=ranker_configs - ) logger.info(f"exp_pool config: {self.config.exp_pool}") + self._storage = self._resolve_storage() + return self._storage - @property - def vector_store(self): - if not self._vector_store: - self._vector_store: ChromaVectorStore = 
self.storage._retriever._vector_store - - return self._vector_store - @handle_exception def create_exp(self, exp: Experience): """Adds an experience to the storage if writing is enabled. @@ -78,6 +56,7 @@ class ExperienceManager(BaseModel): return self.storage.add_objs([exp]) + self.storage.persist(self.config.exp_pool.persist_path) @handle_exception(default_return=[]) async def query_exps(self, req: str, tag: str = "", query_type: QueryType = QueryType.SEMANTIC) -> list[Experience]: @@ -110,7 +89,97 @@ class ExperienceManager(BaseModel): def get_exps_count(self) -> int: """Get the total number of experiences.""" - return self.vector_store._collection.count() + return self.storage.count() + + def _resolve_storage(self) -> "SimpleEngine": + """Selects the appropriate storage creation method based on the configured storage type.""" + + storage_creators = { + ExperiencePoolStorageType.BM25: self._create_bm25_storage, + ExperiencePoolStorageType.CHROMA: self._create_chroma_storage, + } + + return storage_creators[self.config.exp_pool.storage_type]() + + def _create_bm25_storage(self) -> "SimpleEngine": + """Creates or loads BM25 storage. + + This function attempts to create a new BM25 storage if the specified + document store path does not exist. If the path exists, it loads the + existing BM25 storage. + + Returns: + SimpleEngine: An instance of SimpleEngine configured with BM25 storage. + + Raises: + ImportError: If required modules are not installed. 
+ """ + + try: + from llama_index.core import VectorStoreIndex + + from metagpt.rag.engines import SimpleEngine + from metagpt.rag.schema import ( + BM25IndexConfig, + BM25RetrieverConfig, + LLMRankerConfig, + ) + except ImportError: + raise ImportError("To use the experience pool, you need to install the rag module.") + + persist_path = Path(self.config.exp_pool.persist_path) + docstore_path = persist_path / "docstore.json" + + if not docstore_path.exists(): + logger.debug(f"Path `{docstore_path}` not exists, try to create a new bm25 storage.") + exps = [Experience(req="req", resp="resp")] + nodes = SimpleEngine.get_obj_nodes(exps) + embed_model = SimpleEngine._resolve_embed_model(configs=[BM25RetrieverConfig()]) + index = VectorStoreIndex(nodes, embed_model=embed_model) + + retriever_configs = [BM25RetrieverConfig(index=index)] + ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] + + storage = SimpleEngine.from_objs( + objs=exps, retriever_configs=retriever_configs, ranker_configs=ranker_configs + ) + return storage + + logger.debug(f"Path `{docstore_path}` exists, try to load bm25 storage.") + storage = SimpleEngine.from_index( + BM25IndexConfig(persist_path=persist_path), retriever_configs=[BM25RetrieverConfig()] + ) + + return storage + + def _create_chroma_storage(self) -> "SimpleEngine": + """Creates Chroma storage. + + Returns: + SimpleEngine: An instance of SimpleEngine configured with Chroma storage. + + Raises: + ImportError: If required modules are not installed. 
+ """ + + try: + from metagpt.rag.engines import SimpleEngine + from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig + except ImportError: + raise ImportError("To use the experience pool, you need to install the rag module.") + + retriever_configs = [ + ChromaRetrieverConfig( + persist_path=self.config.exp_pool.persist_path, + collection_name=DEFAULT_COLLECTION_NAME, + similarity_top_k=DEFAULT_SIMILARITY_TOP_K, + ) + ] + ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] + + storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs) + + return storage _exp_manager = None diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 8a9ccaffd..be4c3daf5 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -37,7 +37,11 @@ from metagpt.rag.factories import ( get_retriever, ) from metagpt.rag.interface import NoEmbedding, RAGObject -from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever +from metagpt.rag.retrievers.base import ( + ModifiableRAGRetriever, + PersistableRAGRetriever, + QueryableRAGRetriever, +) from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever from metagpt.rag.schema import ( BaseIndexConfig, @@ -144,7 +148,7 @@ class SimpleEngine(RetrieverQueryEngine): if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs): raise ValueError("In BM25RetrieverConfig, Objs must not be empty.") - nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] + nodes = cls.get_obj_nodes(objs) return cls._from_nodes( nodes=nodes, @@ -201,7 +205,7 @@ class SimpleEngine(RetrieverQueryEngine): """Adds objects to the retriever, storing each object's original form in metadata for future reference.""" self._ensure_retriever_modifiable() - nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in 
objs] + nodes = self.get_obj_nodes(objs) self._save_nodes(nodes) def persist(self, persist_dir: Union[str, os.PathLike], **kwargs): @@ -210,6 +214,18 @@ class SimpleEngine(RetrieverQueryEngine): self._persist(str(persist_dir), **kwargs) + def count(self) -> int: + """Count.""" + self._ensure_retriever_queryable() + + return self._retriever.query_total_count() + + @staticmethod + def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]: + """Converts a list of RAGObjects to a list of ObjectNodes.""" + + return [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] + @classmethod def _from_nodes( cls, @@ -258,6 +274,9 @@ class SimpleEngine(RetrieverQueryEngine): def _ensure_retriever_persistable(self): self._ensure_retriever_of_type(PersistableRAGRetriever) + def _ensure_retriever_queryable(self): + self._ensure_retriever_of_type(QueryableRAGRetriever) + def _ensure_retriever_of_type(self, required_type: BaseRetriever): """Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever. 
diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 1460e131b..b3b110330 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -84,6 +84,8 @@ class RetrieverFactory(ConfigBasedFactory): def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: index = self._extract_index(config, **kwargs) nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs) + if index and not config.index: + config.index = index return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py index a7b836833..5bd04adca 100644 --- a/metagpt/rag/retrievers/base.py +++ b/metagpt/rag/retrievers/base.py @@ -45,3 +45,17 @@ class PersistableRAGRetriever(RAGRetriever): @abstractmethod def persist(self, persist_dir: str, **kwargs) -> None: """To support persist, must inplement this func""" + + +class QueryableRAGRetriever(RAGRetriever): + """Support querying total count.""" + + @classmethod + def __subclasshook__(cls, C): + if cls is QueryableRAGRetriever: + return check_methods(C, "query_total_count") + return NotImplemented + + @abstractmethod + def query_total_count(self) -> int: + """To support querying total count, must implement this func""" diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index dc75d87b0..ace1bb86c 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -47,3 +47,8 @@ class DynamicBM25Retriever(BM25Retriever): """Support persist.""" if self._index: self._index.storage_context.persist(persist_dir) + + def query_total_count(self) -> int: + """Support query total count.""" + + return len(self._nodes) diff --git a/metagpt/rag/retrievers/chroma_retriever.py b/metagpt/rag/retrievers/chroma_retriever.py index d41f375e4..6c466e49f 100644 --- 
a/metagpt/rag/retrievers/chroma_retriever.py +++ b/metagpt/rag/retrievers/chroma_retriever.py @@ -2,6 +2,7 @@ from llama_index.core.retrievers import VectorIndexRetriever from llama_index.core.schema import BaseNode +from llama_index.vector_stores.chroma import ChromaVectorStore class ChromaRetriever(VectorIndexRetriever): @@ -15,3 +16,10 @@ class ChromaRetriever(VectorIndexRetriever): """Support persist. Chromadb automatically saves, so there is no need to implement.""" + + def query_total_count(self) -> int: + """Support query total count.""" + + vector_store: ChromaVectorStore = self._vector_store + + return vector_store._collection.count() diff --git a/tests/metagpt/exp_pool/test_manager.py b/tests/metagpt/exp_pool/test_manager.py index 4d298a44e..c99d28ba6 100644 --- a/tests/metagpt/exp_pool/test_manager.py +++ b/tests/metagpt/exp_pool/test_manager.py @@ -1,7 +1,10 @@ import pytest from metagpt.config2 import Config -from metagpt.configs.exp_pool_config import ExperiencePoolConfig +from metagpt.configs.exp_pool_config import ( + ExperiencePoolConfig, + ExperiencePoolStorageType, +) from metagpt.configs.llm_config import LLMConfig from metagpt.exp_pool.manager import Experience, ExperienceManager from metagpt.exp_pool.schema import QueryType @@ -10,17 +13,19 @@ from metagpt.exp_pool.schema import QueryType class TestExperienceManager: @pytest.fixture def mock_config(self): - return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, enabled=True)) + return Config( + llm=LLMConfig(), + exp_pool=ExperiencePoolConfig( + enable_write=True, enable_read=True, enabled=True, storage_type=ExperiencePoolStorageType.BM25 + ), + ) @pytest.fixture def mock_storage(self, mocker): engine = mocker.MagicMock() engine.add_objs = mocker.MagicMock() engine.aretrieve = mocker.AsyncMock(return_value=[]) - engine._retriever = mocker.MagicMock() - engine._retriever._vector_store = mocker.MagicMock() - engine._retriever._vector_store._collection = 
mocker.MagicMock() - engine._retriever._vector_store._collection.count = mocker.MagicMock(return_value=10) + engine.count = mocker.MagicMock(return_value=10) return engine @pytest.fixture @@ -29,8 +34,33 @@ class TestExperienceManager: manager._storage = mock_storage return manager - def test_vector_store_property(self, exp_manager): - assert exp_manager.vector_store == exp_manager.storage._retriever._vector_store + def test_storage_property(self, exp_manager, mock_storage): + assert exp_manager.storage == mock_storage + + def test_storage_property_initialization(self, mocker, mock_config): + mocker.patch.object(ExperienceManager, "_resolve_storage", return_value=mocker.MagicMock()) + manager = ExperienceManager(config=mock_config) + assert manager._storage is None + _ = manager.storage + assert manager._storage is not None + + def test_create_exp_write_disabled(self, exp_manager, mock_config): + mock_config.exp_pool.enable_write = False + exp = Experience(req="test", resp="response") + exp_manager.create_exp(exp) + exp_manager.storage.add_objs.assert_not_called() + + def test_create_exp_write_enabled(self, exp_manager): + exp = Experience(req="test", resp="response") + exp_manager.create_exp(exp) + exp_manager.storage.add_objs.assert_called_once_with([exp]) + exp_manager.storage.persist.assert_called_once_with(exp_manager.config.exp_pool.persist_path) + + @pytest.mark.asyncio + async def test_query_exps_read_disabled(self, exp_manager, mock_config): + mock_config.exp_pool.enable_read = False + result = await exp_manager.query_exps("query") + assert result == [] @pytest.mark.asyncio async def test_query_exps_with_exact_match(self, exp_manager, mocker): @@ -65,14 +95,37 @@ class TestExperienceManager: def test_get_exps_count(self, exp_manager): assert exp_manager.get_exps_count() == 10 - def test_create_exp_write_disabled(self, exp_manager, mock_config): - mock_config.exp_pool.enable_write = False - exp = Experience(req="test", resp="response") - 
exp_manager.create_exp(exp) - exp_manager.storage.add_objs.assert_not_called() + def test_resolve_storage_bm25(self, mocker, mock_config): + mock_config.exp_pool.storage_type = ExperiencePoolStorageType.BM25 + mocker.patch.object(ExperienceManager, "_create_bm25_storage", return_value=mocker.MagicMock()) + manager = ExperienceManager(config=mock_config) + storage = manager._resolve_storage() + manager._create_bm25_storage.assert_called_once() + assert storage is not None - @pytest.mark.asyncio - async def test_query_exps_read_disabled(self, exp_manager, mock_config): - mock_config.exp_pool.enable_read = False - result = await exp_manager.query_exps("query") - assert result == [] + def test_resolve_storage_chroma(self, mocker, mock_config): + mock_config.exp_pool.storage_type = ExperiencePoolStorageType.CHROMA + mocker.patch.object(ExperienceManager, "_create_chroma_storage", return_value=mocker.MagicMock()) + manager = ExperienceManager(config=mock_config) + storage = manager._resolve_storage() + manager._create_chroma_storage.assert_called_once() + assert storage is not None + + def test_create_bm25_storage(self, mocker, mock_config): + mocker.patch("metagpt.rag.engines.SimpleEngine.from_objs", return_value=mocker.MagicMock()) + mocker.patch("metagpt.rag.engines.SimpleEngine.from_index", return_value=mocker.MagicMock()) + mocker.patch("metagpt.rag.engines.SimpleEngine.get_obj_nodes", return_value=[]) + mocker.patch("metagpt.rag.engines.SimpleEngine._resolve_embed_model", return_value=mocker.MagicMock()) + mocker.patch("llama_index.core.VectorStoreIndex", return_value=mocker.MagicMock()) + mocker.patch("metagpt.rag.schema.BM25RetrieverConfig", return_value=mocker.MagicMock()) + mocker.patch("pathlib.Path.exists", return_value=False) + + manager = ExperienceManager(config=mock_config) + storage = manager._create_bm25_storage() + assert storage is not None + + def test_create_chroma_storage(self, mocker, mock_config): + 
mocker.patch("metagpt.rag.engines.SimpleEngine.from_objs", return_value=mocker.MagicMock()) + manager = ExperienceManager(config=mock_config) + storage = manager._create_chroma_storage() + assert storage is not None