diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index c237dcf69..03b645420 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -220,7 +220,9 @@ class SimpleEngine(RetrieverQueryEngine): embed_model = cls._resolve_embed_model(embed_model, retriever_configs) llm = llm or get_rag_llm() - retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model) + retriever = get_retriever( + configs=retriever_configs, nodes=nodes, embed_model=embed_model + ) # Default VectorStoreIndex(nodes, embed_model).as_retriever rankers = get_rankers(configs=ranker_configs, llm=llm) # Default [] return cls( diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index dd6261d52..a3b2268fd 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -1,6 +1,8 @@ """RAG Retriever Factory.""" +from functools import wraps + import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex @@ -28,6 +30,22 @@ from metagpt.rag.schema import ( ) +def get_or_build_index(build_index_func): + """Find index using `_extract_index` method. + + If no index is found, using build_index_func. + """ + + @wraps(build_index_func) + def wrapper(self, config, **kwargs): + index = self._extract_index(config, **kwargs) + if index is not None: + return index + return build_index_func(self, config, **kwargs) + + return wrapper + + class RetrieverFactory(ConfigBasedFactory): """Modify creators for dynamically instance implementation.""" @@ -59,12 +77,13 @@ class RetrieverFactory(ConfigBasedFactory): return index.as_retriever() def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: - config.index = self._extract_index(config, **kwargs) or self._build_faiss_index(config, **kwargs) + config.index = self._build_faiss_index(config, **kwargs) return FAISSRetriever(**config.model_dump()) def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: - nodes = self._extract_nodes(config, **kwargs) + index = self._extract_index(config, **kwargs) + nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs) return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) @@ -95,11 +114,13 @@ class RetrieverFactory(ConfigBasedFactory): return index + @get_or_build_index def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex: vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) return self._build_index_from_vector_store(config, vector_store, **kwargs) + @get_or_build_index def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex: db = chromadb.PersistentClient(path=str(config.persist_path)) chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) @@ -107,6 +128,7 @@ class RetrieverFactory(ConfigBasedFactory): return self._build_index_from_vector_store(config, vector_store, **kwargs) + @get_or_build_index def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex: vector_store = ElasticsearchStore(**config.store_config.model_dump()) diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py index a70639f55..cd55a32db 100644 --- a/tests/metagpt/rag/factories/test_retriever.py +++ b/tests/metagpt/rag/factories/test_retriever.py @@ -119,3 +119,19 @@ class TestRetrieverFactory: extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index) assert extracted_index == mock_vector_store_index + + def test_get_or_build_when_get(self, mocker): + want = "existing_index" + mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want) + + got = self.retriever_factory._build_es_index(None) + + assert got == want + + def test_get_or_build_when_build(self, mocker): + want = "call_build_es_index" + mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want) + + got = self.retriever_factory._build_es_index(None) + + assert got == want