diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index cdac64fa6..6aad695e7 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -29,9 +29,6 @@ class RAGIndexFactory(ConfigBasedFactory): """Key is PersistType.""" return super().get_instance(config, **kwargs) - def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding: - return self._val_from_config_or_kwargs("embed_model", config, **kwargs) - def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex: embed_model = self._extract_embed_model(config, **kwargs) @@ -59,5 +56,8 @@ class RAGIndexFactory(ConfigBasedFactory): index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model) return index + def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding: + return self._val_from_config_or_kwargs("embed_model", config, **kwargs) + get_index = RAGIndexFactory().get_index diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index 753041c6b..f05599e15 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -24,12 +24,12 @@ class RankerFactory(ConfigBasedFactory): return super().get_instances(configs, **kwargs) - def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM: - return self._val_from_config_or_kwargs("llm", config, **kwargs) - def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank: config.llm = self._extract_llm(config, **kwargs) return LLMRerank(**config.model_dump()) + def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM: + return self._val_from_config_or_kwargs("llm", config, **kwargs) + get_rankers = RankerFactory().get_rankers diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index e5e810b45..ba48c753e 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -50,21 +50,6 @@ class RetrieverFactory(ConfigBasedFactory): def _create_default(self, **kwargs) -> RAGRetriever: return self._extract_index(**kwargs).as_retriever() - def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: - return self._val_from_config_or_kwargs("index", config, **kwargs) - - def _build_index_from_vector_store( - self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs - ) -> VectorStoreIndex: - storage_context = StorageContext.from_defaults(vector_store=vector_store) - old_index = self._extract_index(config, **kwargs) - new_index = VectorStoreIndex( - nodes=list(old_index.docstore.docs.values()), - storage_context=storage_context, - embed_model=old_index._embed_model, - ) - return new_index - def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) @@ -82,5 +67,20 @@ class RetrieverFactory(ConfigBasedFactory): config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) return ChromaRetriever(**config.model_dump()) + def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: + return self._val_from_config_or_kwargs("index", config, **kwargs) + + def _build_index_from_vector_store( + self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs + ) -> VectorStoreIndex: + storage_context = StorageContext.from_defaults(vector_store=vector_store) + old_index = self._extract_index(config, **kwargs) + new_index = VectorStoreIndex( + nodes=list(old_index.docstore.docs.values()), + storage_context=storage_context, + embed_model=old_index._embed_model, + ) + return new_index + get_retriever = RetrieverFactory().get_retriever