diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index d6922ff00..9bf289038 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -1,7 +1,8 @@ """Experience Manager.""" -from llama_index.vector_stores.chroma import ChromaVectorStore -from pydantic import BaseModel, ConfigDict, model_validator +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, ConfigDict from metagpt.config2 import Config, config from metagpt.exp_pool.schema import ( @@ -11,27 +12,37 @@ from metagpt.exp_pool.schema import ( QueryType, ) from metagpt.logs import logger -from metagpt.rag.engines import SimpleEngine -from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig from metagpt.utils.exceptions import handle_exception +if TYPE_CHECKING: + from llama_index.vector_stores.chroma import ChromaVectorStore + class ExperienceManager(BaseModel): """ExperienceManager manages the lifecycle of experiences, including CRUD and optimization. Args: config (Config): Configuration for managing experiences. - storage (SimpleEngine): Engine to handle the storage and retrieval of experiences. + _storage (SimpleEngine): Engine to handle the storage and retrieval of experiences. + _vector_store (ChromaVectorStore): The actual place where vectors are stored. """ model_config = ConfigDict(arbitrary_types_allowed=True) config: Config = config - storage: SimpleEngine = None - @model_validator(mode="after") - def initialize(self): - if self.storage is None: + _storage: Any = None + _vector_store: Any = None + + @property + def storage(self): + if self._storage is None: + try: + from metagpt.rag.engines import SimpleEngine + from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig + except ImportError: + raise ImportError("To use the experience pool, you need to install the rag module.") + retriever_configs = [ ChromaRetrieverConfig( persist_path=self.config.exp_pool.persist_path, @@ -41,14 +52,19 @@ class ExperienceManager(BaseModel): ] ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] - self.storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs) + self._storage: SimpleEngine = SimpleEngine.from_objs( + retriever_configs=retriever_configs, ranker_configs=ranker_configs + ) + logger.debug(f"exp_pool config: {self.config.exp_pool}") - logger.debug(f"exp_pool config: {self.config.exp_pool}") - return self + return self._storage @property - def vector_store(self) -> ChromaVectorStore: - return self.storage._retriever._vector_store + def vector_store(self): + if not self._vector_store: + self._vector_store: ChromaVectorStore = self.storage._retriever._vector_store + + return self._vector_store @handle_exception def create_exp(self, exp: Experience):