diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index 3e0c13c25..504faafc6 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -23,11 +23,11 @@ class RAGIndexFactory(ConfigFactory): """Key is PersistType.""" return super().get_instance(config, **kwargs) - def extract_embed_model(self, config, **kwargs) -> BaseEmbedding: + def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding: return self._val_from_config_or_kwargs("embed_model", config, **kwargs) def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex: - embed_model = self.extract_embed_model(config, **kwargs) + embed_model = self._extract_embed_model(config, **kwargs) vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path)) storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path) @@ -35,7 +35,7 @@ class RAGIndexFactory(ConfigFactory): return index def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: - embed_model = self.extract_embed_model(config, **kwargs) + embed_model = self._extract_embed_model(config, **kwargs) db = chromadb.PersistentClient(str(config.persist_path)) chroma_collection = db.get_or_create_collection(config.collection_name)