lazy import rag

This commit is contained in:
seehi 2024-07-16 17:28:54 +08:00
parent 8b9e992b56
commit 0b604c42b5

View file

@ -1,7 +1,8 @@
"""Experience Manager."""
from llama_index.vector_stores.chroma import ChromaVectorStore
from pydantic import BaseModel, ConfigDict, model_validator
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict
from metagpt.config2 import Config, config
from metagpt.exp_pool.schema import (
@ -11,27 +12,37 @@ from metagpt.exp_pool.schema import (
QueryType,
)
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
from metagpt.utils.exceptions import handle_exception
if TYPE_CHECKING:
from llama_index.vector_stores.chroma import ChromaVectorStore
class ExperienceManager(BaseModel):
"""ExperienceManager manages the lifecycle of experiences, including CRUD and optimization.
Args:
config (Config): Configuration for managing experiences.
storage (SimpleEngine): Engine to handle the storage and retrieval of experiences.
_storage (SimpleEngine): Engine to handle the storage and retrieval of experiences.
_vector_store (ChromaVectorStore): The actual place where vectors are stored.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
config: Config = config
storage: SimpleEngine = None
@model_validator(mode="after")
def initialize(self):
if self.storage is None:
_storage: Any = None
_vector_store: Any = None
@property
def storage(self):
if self._storage is None:
try:
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
except ImportError:
raise ImportError("To use the experience pool, you need to install the rag module.")
retriever_configs = [
ChromaRetrieverConfig(
persist_path=self.config.exp_pool.persist_path,
@ -41,14 +52,19 @@ class ExperienceManager(BaseModel):
]
ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)]
self.storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)
self._storage: SimpleEngine = SimpleEngine.from_objs(
retriever_configs=retriever_configs, ranker_configs=ranker_configs
)
logger.debug(f"exp_pool config: {self.config.exp_pool}")
logger.debug(f"exp_pool config: {self.config.exp_pool}")
return self
return self._storage
@property
def vector_store(self) -> ChromaVectorStore:
return self.storage._retriever._vector_store
def vector_store(self):
if not self._vector_store:
self._vector_store: ChromaVectorStore = self.storage._retriever._vector_store
return self._vector_store
@handle_exception
def create_exp(self, exp: Experience):