From dd965a21493a0ba1cab6ebfa13b2a3a7229ca74a Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 7 Feb 2024 15:20:21 +0800 Subject: [PATCH] simplify rag factory --- metagpt/rag/engines/simple.py | 13 ++- metagpt/rag/factory.py | 109 ++++++++++++++++++ metagpt/rag/rankers/__init__.py | 5 - metagpt/rag/rankers/factory.py | 37 ------ metagpt/rag/retrievers/__init__.py | 3 +- metagpt/rag/retrievers/base.py | 11 ++ metagpt/rag/retrievers/factory.py | 62 ---------- tests/metagpt/rag/engine/test_simple.py | 4 +- .../rag/rankers/test_ranker_factory.py | 47 -------- ...t_retriever_factory.py => test_factory.py} | 59 +++++++++- 10 files changed, 183 insertions(+), 167 deletions(-) create mode 100644 metagpt/rag/factory.py delete mode 100644 metagpt/rag/rankers/factory.py delete mode 100644 metagpt/rag/retrievers/factory.py delete mode 100644 tests/metagpt/rag/rankers/test_ranker_factory.py rename tests/metagpt/rag/{retrievers/test_retriever_factory.py => test_factory.py} (59%) diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index c4e3b6f31..e71cfc439 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -14,10 +14,9 @@ from llama_index.query_engine import RetrieverQueryEngine from llama_index.response_synthesizers import BaseSynthesizer from llama_index.schema import NodeWithScore, QueryBundle, QueryType +from metagpt.rag.factory import get_rankers, get_retriever from metagpt.rag.llm import get_default_llm -from metagpt.rag.rankers import get_rankers -from metagpt.rag.retrievers import get_retriever -from metagpt.rag.retrievers.base import RAGRetriever +from metagpt.rag.retrievers.base import ModifiableRAGRetriever from metagpt.rag.schema import RankerConfigType, RetrieverConfigType from metagpt.utils.embedding import get_embedding @@ -93,8 +92,10 @@ class SimpleEngine(RetrieverQueryEngine): return await super().aretrieve(query_bundle) def add_docs(self, input_files: list[str]): - """Add docs to retriever""" + """Add docs to retriever. retriever must has add_nodes func""" + if not isinstance(self.retriever, ModifiableRAGRetriever): + raise TypeError(f"must be inplement to add_docs: {type(self.retriever)}") + documents = SimpleDirectoryReader(input_files=input_files).load_data() nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents) - retriever: RAGRetriever = self.retriever - retriever.add_nodes(nodes) + self.retriever.add_nodes(nodes) diff --git a/metagpt/rag/factory.py b/metagpt/rag/factory.py new file mode 100644 index 000000000..4076e43c4 --- /dev/null +++ b/metagpt/rag/factory.py @@ -0,0 +1,109 @@ +"""Factory for creating retriever, ranker""" +from typing import Any, Callable + +import faiss +from llama_index import ServiceContext, StorageContext, VectorStoreIndex +from llama_index.indices.base import BaseIndex +from llama_index.postprocessor import LLMRerank +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.vector_stores.faiss import FaissVectorStore + +from metagpt.rag.retrievers.base import RAGRetriever +from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever +from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever +from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever +from metagpt.rag.schema import ( + BM25RetrieverConfig, + FAISSRetrieverConfig, + LLMRankerConfig, + RankerConfigType, + RetrieverConfigType, +) + + +class BaseFactory: + """ + A base factory class for creating instances based on provided configurations. + It uses a registry of creator functions mapped to configuration types to instantiate objects dynamically. + """ + + def __init__(self, creators: dict[Any, Callable]): + """ + Creators is a dictionary mapping configuration types to creator functions. + The first arg of Creator function should be config. + """ + self.creators = creators + + def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]: + if not configs: + return [self._default_instance(**kwargs)] + + return [self._get_instance(config, **kwargs) for config in configs] + + def _get_instance(self, config: Any, **kwargs) -> Any: + create_func = self.creators.get(type(config)) + if create_func: + return create_func(config, **kwargs) + + raise ValueError(f"Unknown config: {config}") + + def _default_instance(self, **kwargs) -> Any: + raise NotImplementedError("This method should be implemented by subclasses.") + + +class RetrieverFactory(BaseFactory): + def __init__(self): + creators = { + FAISSRetrieverConfig: self._create_faiss_retriever, + BM25RetrieverConfig: self._create_bm25_retriever, + } + super().__init__(creators) + + def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever: + """Creates and returns a retriever instance based on the provided configurations.""" + retrievers = super().get_instances(configs, index=index) + + return ( + SimpleHybridRetriever(*retrievers, service_context=index.service_context) + if len(retrievers) > 1 + else retrievers[0] + ) + + def _default_instance(self, index: BaseIndex) -> RAGRetriever: + return index.as_retriever() + + def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex, **kwargs) -> FAISSRetriever: + vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) + storage_context = StorageContext.from_defaults(vector_store=vector_store) + vector_index = VectorStoreIndex( + nodes=list(index.docstore.docs.values()), + storage_context=storage_context, + service_context=index.service_context, + ) + return FAISSRetriever(**config.model_dump(), index=vector_index) + + def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex, **kwargs) -> DynamicBM25Retriever: + return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index) + + +class RankerFactory(BaseFactory): + def __init__(self): + creators = { + LLMRankerConfig: self._create_llm_ranker, + } + super().__init__(creators) + + def get_rankers( + self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None + ) -> list[BaseNodePostprocessor]: + return super().get_instances(configs, service_context=service_context) + + def _default_instance(self, service_context: ServiceContext = None) -> LLMRerank: + return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context) + + def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None, **kwargs) -> LLMRerank: + return LLMRerank(**config.model_dump(), service_context=service_context) + + +get_retriever = RetrieverFactory().get_retriever +get_rankers = RankerFactory().get_rankers diff --git a/metagpt/rag/rankers/__init__.py b/metagpt/rag/rankers/__init__.py index bb14007ba..82743487c 100644 --- a/metagpt/rag/rankers/__init__.py +++ b/metagpt/rag/rankers/__init__.py @@ -1,6 +1 @@ """Rankers init""" - -from metagpt.rag.rankers.factory import get_rankers - - -__all__ = ["get_rankers"] diff --git a/metagpt/rag/rankers/factory.py b/metagpt/rag/rankers/factory.py deleted file mode 100644 index b139fdd92..000000000 --- a/metagpt/rag/rankers/factory.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Rankers Factory""" -from llama_index import ServiceContext -from llama_index.postprocessor import LLMRerank -from llama_index.postprocessor.types import BaseNodePostprocessor - -from metagpt.rag.schema import LLMRankerConfig, RankerConfigType - - -class RankerFactory: - def __init__(self): - self.ranker_creators = { - LLMRankerConfig: self._create_llm_ranker, - } - - def get_rankers( - self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None - ) -> list[BaseNodePostprocessor]: - if not configs: - return [self._default_ranker(service_context)] - - return [self._get_ranker(config, service_context) for config in configs] - - def _default_ranker(self, service_context: ServiceContext = None) -> LLMRerank: - return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context) - - def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None) -> BaseNodePostprocessor: - create_func = self.ranker_creators.get(type(config)) - if create_func: - return create_func(config, service_context) - - raise ValueError(f"Unknown ranker config: {config}") - - def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None) -> LLMRerank: - return LLMRerank(**config.model_dump(), service_context=service_context) - - -get_rankers = RankerFactory().get_rankers diff --git a/metagpt/rag/retrievers/__init__.py b/metagpt/rag/retrievers/__init__.py index 88cb4cc77..7f4371423 100644 --- a/metagpt/rag/retrievers/__init__.py +++ b/metagpt/rag/retrievers/__init__.py @@ -1,6 +1,5 @@ """Retrievers init""" from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever -from metagpt.rag.retrievers.factory import get_retriever -__all__ = ["SimpleHybridRetriever", "get_retriever"] +__all__ = ["SimpleHybridRetriever"] diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py index 97590a138..5d509f0e2 100644 --- a/metagpt/rag/retrievers/base.py +++ b/metagpt/rag/retrievers/base.py @@ -17,5 +17,16 @@ class RAGRetriever(BaseRetriever): def _retrieve(self, query: QueryType) -> list[NodeWithScore]: """Retrieve nodes""" + +class ModifiableRAGRetriever(RAGRetriever): + """Support modification.""" + + @classmethod + def __subclasshook__(cls, C): + if any("add_nodes" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + + @abstractmethod def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: """To support add docs, must inplement this func""" diff --git a/metagpt/rag/retrievers/factory.py b/metagpt/rag/retrievers/factory.py deleted file mode 100644 index c2dcb2725..000000000 --- a/metagpt/rag/retrievers/factory.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Retriever Factory""" -import faiss -from llama_index import StorageContext, VectorStoreIndex -from llama_index.indices.base import BaseIndex -from llama_index.vector_stores.faiss import FaissVectorStore - -from metagpt.rag.retrievers.base import RAGRetriever -from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever -from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever -from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever -from metagpt.rag.schema import ( - BM25RetrieverConfig, - FAISSRetrieverConfig, - RetrieverConfigType, -) - - -class RetrieverFactory: - def __init__(self): - self.retriever_creators = { - FAISSRetrieverConfig: self._create_faiss_retriever, - BM25RetrieverConfig: self._create_bm25_retriever, - } - - def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever: - """Creates and returns a retriever instance based on the provided configurations.""" - if not configs: - return self._default_retriever(index) - - retrievers = [self._get_retriever(index, config) for config in configs] - - return ( - SimpleHybridRetriever(*retrievers, service_context=index.service_context) - if len(retrievers) > 1 - else retrievers[0] - ) - - def _default_retriever(self, index: BaseIndex) -> RAGRetriever: - return index.as_retriever() - - def _get_retriever(self, index: BaseIndex, config: RetrieverConfigType) -> RAGRetriever: - create_func = self.retriever_creators.get(type(config)) - if create_func: - return create_func(index, config) - - raise ValueError(f"Unknown retriever config: {config}") - - def _create_faiss_retriever(self, index: BaseIndex, config: FAISSRetrieverConfig): - vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) - storage_context = StorageContext.from_defaults(vector_store=vector_store) - vector_index = VectorStoreIndex( - nodes=list(index.docstore.docs.values()), - storage_context=storage_context, - service_context=index.service_context, - ) - return FAISSRetriever(vector_index, **config.model_dump()) - - def _create_bm25_retriever(self, index: BaseIndex, config: BM25RetrieverConfig): - return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index) - - -get_retriever = RetrieverFactory().get_retriever diff --git a/tests/metagpt/rag/engine/test_simple.py b/tests/metagpt/rag/engine/test_simple.py index 4d047b075..ceec4d63a 100644 --- a/tests/metagpt/rag/engine/test_simple.py +++ b/tests/metagpt/rag/engine/test_simple.py @@ -2,7 +2,7 @@ import pytest from llama_index import VectorStoreIndex from metagpt.rag.engines import SimpleEngine -from metagpt.rag.retrievers.base import RAGRetriever +from metagpt.rag.retrievers.base import ModifiableRAGRetriever class TestSimpleEngine: @@ -99,7 +99,7 @@ class TestSimpleEngine: mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"] - mock_retriever = mocker.MagicMock(spec=RAGRetriever) + mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever) mock_index = mocker.MagicMock(spec=VectorStoreIndex) mock_index.service_context.node_parser.get_nodes_from_documents = lambda x: ["node1", "node2"] diff --git a/tests/metagpt/rag/rankers/test_ranker_factory.py b/tests/metagpt/rag/rankers/test_ranker_factory.py deleted file mode 100644 index ec335cee2..000000000 --- a/tests/metagpt/rag/rankers/test_ranker_factory.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -from llama_index import ServiceContext -from llama_index.postprocessor import LLMRerank - -from metagpt.rag.rankers.factory import RankerFactory -from metagpt.rag.schema import LLMRankerConfig - - -class TestRankerFactory: - @pytest.fixture - def mock_service_context(self, mocker): - return mocker.MagicMock(spec=ServiceContext) - - def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context): - # Setup - factory = RankerFactory() - - # Execute - rankers = factory.get_rankers(service_context=mock_service_context) - - # Assertions - assert len(rankers) == 1 - assert isinstance(rankers[0], LLMRerank) - - def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context): - # Setup - config = LLMRankerConfig(top_n=3) - factory = RankerFactory() - - # Execute - rankers = factory.get_rankers(configs=[config], service_context=mock_service_context) - - # Assertions - assert len(rankers) == 1 - assert isinstance(rankers[0], LLMRerank) - assert rankers[0].top_n == 3 - - def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context): - # Mock - mock_config = mocker.MagicMock() # 使用 MagicMock 来模拟一个未知的配置类型 - - # Setup - factory = RankerFactory() - - # Execute & Assertions - with pytest.raises(ValueError): - factory.get_rankers(configs=[mock_config], service_context=mock_service_context) diff --git a/tests/metagpt/rag/retrievers/test_retriever_factory.py b/tests/metagpt/rag/test_factory.py similarity index 59% rename from tests/metagpt/rag/retrievers/test_retriever_factory.py rename to tests/metagpt/rag/test_factory.py index dc69a49fc..70e0809a9 100644 --- a/tests/metagpt/rag/retrievers/test_retriever_factory.py +++ b/tests/metagpt/rag/test_factory.py @@ -1,12 +1,18 @@ import pytest +from llama_index import ServiceContext from llama_index.indices.base import BaseIndex +from llama_index.postprocessor import LLMRerank +from metagpt.rag.factory import RankerFactory, RetrieverFactory from metagpt.rag.retrievers.base import RAGRetriever from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever -from metagpt.rag.retrievers.factory import RetrieverFactory from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever -from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig +from metagpt.rag.schema import ( + BM25RetrieverConfig, + FAISSRetrieverConfig, + LLMRankerConfig, +) class TestRetrieverFactory: @@ -28,20 +34,20 @@ class TestRetrieverFactory: @pytest.fixture def mock_faiss_vector_store(self, mocker): - return mocker.patch("metagpt.rag.retrievers.factory.FaissVectorStore") + return mocker.patch("metagpt.rag.factory.FaissVectorStore") @pytest.fixture def mock_storage_context(self, mocker): - return mocker.patch("metagpt.rag.retrievers.factory.StorageContext") + return mocker.patch("metagpt.rag.factory.StorageContext") @pytest.fixture def mock_vector_store_index(self, mocker): - return mocker.patch("metagpt.rag.retrievers.factory.VectorStoreIndex") + return mocker.patch("metagpt.rag.factory.VectorStoreIndex") @pytest.fixture def mock_dynamic_bm25_retriever(self, mocker): mock = mocker.MagicMock(spec=DynamicBM25Retriever) - return mocker.patch("metagpt.rag.retrievers.factory.DynamicBM25Retriever", mock) + return mocker.patch("metagpt.rag.factory.DynamicBM25Retriever", mock) def test_get_retriever_with_no_configs_returns_default_retriever(self, mock_base_index): factory = RetrieverFactory() @@ -81,3 +87,44 @@ class TestRetrieverFactory: factory = RetrieverFactory() with pytest.raises(ValueError): factory.get_retriever(index=mock_base_index, configs=[mock_unknown_config]) + + +class TestRankerFactory: + @pytest.fixture + def mock_service_context(self, mocker): + return mocker.MagicMock(spec=ServiceContext) + + def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context): + # Setup + factory = RankerFactory() + + # Execute + rankers = factory.get_rankers(service_context=mock_service_context) + + # Assertions + assert len(rankers) == 1 + assert isinstance(rankers[0], LLMRerank) + + def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context): + # Setup + config = LLMRankerConfig(top_n=3) + factory = RankerFactory() + + # Execute + rankers = factory.get_rankers(configs=[config], service_context=mock_service_context) + + # Assertions + assert len(rankers) == 1 + assert isinstance(rankers[0], LLMRerank) + assert rankers[0].top_n == 3 + + def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context): + # Mock + mock_config = mocker.MagicMock() # 使用 MagicMock 来模拟一个未知的配置类型 + + # Setup + factory = RankerFactory() + + # Execute & Assertions + with pytest.raises(ValueError): + factory.get_rankers(configs=[mock_config], service_context=mock_service_context)