simplify rag factory

This commit is contained in:
seehi 2024-02-07 15:20:21 +08:00
parent a4c095300c
commit dd965a2149
10 changed files with 183 additions and 167 deletions

View file

@ -14,10 +14,9 @@ from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import BaseSynthesizer
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
from metagpt.rag.factory import get_rankers, get_retriever
from metagpt.rag.llm import get_default_llm
from metagpt.rag.rankers import get_rankers
from metagpt.rag.retrievers import get_retriever
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.schema import RankerConfigType, RetrieverConfigType
from metagpt.utils.embedding import get_embedding
@ -93,8 +92,10 @@ class SimpleEngine(RetrieverQueryEngine):
return await super().aretrieve(query_bundle)
def add_docs(self, input_files: list[str]):
    """Load the given files, parse them into nodes and add the nodes to the retriever.

    Args:
        input_files: Paths of the files read through SimpleDirectoryReader.

    Raises:
        TypeError: If the retriever is not a ModifiableRAGRetriever and so
            cannot accept new nodes.
    """
    if not isinstance(self.retriever, ModifiableRAGRetriever):
        raise TypeError(f"retriever must implement add_nodes to support add_docs: {type(self.retriever)}")

    documents = SimpleDirectoryReader(input_files=input_files).load_data()
    nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
    # Add the nodes exactly once; a second add_nodes call would register every node twice.
    self.retriever.add_nodes(nodes)

109
metagpt/rag/factory.py Normal file
View file

@ -0,0 +1,109 @@
"""Factory for creating retriever, ranker"""
from typing import Any, Callable
import faiss
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
from llama_index.indices.base import BaseIndex
from llama_index.postprocessor import LLMRerank
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
RankerConfigType,
RetrieverConfigType,
)
class BaseFactory:
    """Create instances from configuration objects via a registry of creator functions.

    The registry maps a configuration type to the function that builds the
    matching object. Subclasses provide the registry and override
    ``_default_instance`` for the case where no configuration is supplied.
    """

    def __init__(self, creators: dict[Any, Callable]):
        """Store the creator registry.

        Args:
            creators: Maps configuration types to creator functions. Every
                creator is invoked as ``creator(config, **kwargs)``, so its
                first parameter must be the config.
        """
        self.creators = creators

    def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]:
        """Build one instance per config, or a single default instance.

        Args:
            configs: Configuration objects; ``None`` or an empty list selects
                the default instance.
            **kwargs: Extra arguments forwarded to the creator functions or to
                ``_default_instance``.

        Returns:
            One instance per config, in order, or a one-element list holding
            the default instance.

        Raises:
            ValueError: If no creator is registered for a config.
            NotImplementedError: If the default is needed but the subclass did
                not override ``_default_instance``.
        """
        if not configs:
            return [self._default_instance(**kwargs)]

        return [self._get_instance(config, **kwargs) for config in configs]

    def _get_instance(self, config: Any, **kwargs) -> Any:
        """Build the instance for one config using its registered creator.

        Raises:
            ValueError: If neither the config's type nor any of its base
                classes has a registered creator.
        """
        # Walk the MRO so a subclass of a registered config type reuses its
        # parent's creator. The exact type comes first, so exact matches win.
        for config_type in type(config).__mro__:
            create_func = self.creators.get(config_type)
            if create_func is not None:
                return create_func(config, **kwargs)

        raise ValueError(f"Unknown config: {config}")

    def _default_instance(self, **kwargs) -> Any:
        """Build the instance used when no configs are given; subclasses must override."""
        raise NotImplementedError("This method should be implemented by subclasses.")
class RetrieverFactory(BaseFactory):
    """Factory that turns retriever configs into retrievers built on an index."""

    def __init__(self):
        """Register one creator per supported retriever config type."""
        registry = {
            FAISSRetrieverConfig: self._create_faiss_retriever,
            BM25RetrieverConfig: self._create_bm25_retriever,
        }
        super().__init__(registry)

    def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
        """Create the retriever described by the configs.

        A single resulting retriever is returned as is; several are wrapped
        in a SimpleHybridRetriever that shares the index's service context.
        """
        built = super().get_instances(configs, index=index)
        if len(built) > 1:
            return SimpleHybridRetriever(*built, service_context=index.service_context)
        return built[0]

    def _default_instance(self, index: BaseIndex) -> RAGRetriever:
        """Return the index's own retriever, used when no config is supplied."""
        return index.as_retriever()

    def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex, **kwargs) -> FAISSRetriever:
        """Build a FAISS-backed vector index over the index's stored nodes and wrap it in a FAISSRetriever."""
        # Flat L2 FAISS index sized by the configured dimensions.
        faiss_index = faiss.IndexFlatL2(config.dimensions)
        store = FaissVectorStore(faiss_index=faiss_index)
        context = StorageContext.from_defaults(vector_store=store)
        stored_nodes = list(index.docstore.docs.values())
        vector_index = VectorStoreIndex(
            nodes=stored_nodes,
            storage_context=context,
            service_context=index.service_context,
        )
        return FAISSRetriever(**config.model_dump(), index=vector_index)

    def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex, **kwargs) -> DynamicBM25Retriever:
        """Build a DynamicBM25Retriever from the config fields and the given index."""
        options = config.model_dump()
        return DynamicBM25Retriever.from_defaults(**options, index=index)
class RankerFactory(BaseFactory):
    """Factory that turns ranker configs into node postprocessors."""

    def __init__(self):
        """Register one creator per supported ranker config type."""
        registry = {
            LLMRankerConfig: self._create_llm_ranker,
        }
        super().__init__(registry)

    def get_rankers(
        self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
    ) -> list[BaseNodePostprocessor]:
        """Create one ranker per config, or a single default ranker when no configs are given."""
        rankers = super().get_instances(configs, service_context=service_context)
        return rankers

    def _default_instance(self, service_context: ServiceContext = None) -> LLMRerank:
        """Return an LLMRerank whose top_n comes from a default-constructed LLMRankerConfig."""
        default_top_n = LLMRankerConfig().top_n
        return LLMRerank(top_n=default_top_n, service_context=service_context)

    def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None, **kwargs) -> LLMRerank:
        """Return an LLMRerank built from the config fields and the service context."""
        options = config.model_dump()
        return LLMRerank(**options, service_context=service_context)
# Module-level entry points: each name is the bound method of a factory
# instance created once, when this module is imported.
get_retriever = RetrieverFactory().get_retriever
get_rankers = RankerFactory().get_rankers

View file

@ -1,6 +1 @@
"""Rankers init"""
from metagpt.rag.rankers.factory import get_rankers
__all__ = ["get_rankers"]

View file

@ -1,37 +0,0 @@
"""Rankers Factory"""
from llama_index import ServiceContext
from llama_index.postprocessor import LLMRerank
from llama_index.postprocessor.types import BaseNodePostprocessor
from metagpt.rag.schema import LLMRankerConfig, RankerConfigType
class RankerFactory:
    """Factory that turns ranker configs into node postprocessors."""

    def __init__(self):
        """Register one creator per supported ranker config type."""
        registry = {
            LLMRankerConfig: self._create_llm_ranker,
        }
        self.ranker_creators = registry

    def get_rankers(
        self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
    ) -> list[BaseNodePostprocessor]:
        """Create one ranker per config, or a single default ranker when no configs are given."""
        if configs:
            return [self._get_ranker(cfg, service_context) for cfg in configs]
        return [self._default_ranker(service_context)]

    def _default_ranker(self, service_context: ServiceContext = None) -> LLMRerank:
        """Return an LLMRerank whose top_n comes from a default-constructed LLMRankerConfig."""
        default_top_n = LLMRankerConfig().top_n
        return LLMRerank(top_n=default_top_n, service_context=service_context)

    def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None) -> BaseNodePostprocessor:
        """Build the ranker for one config; raise ValueError when its type is not registered."""
        creator = self.ranker_creators.get(type(config))
        if not creator:
            raise ValueError(f"Unknown ranker config: {config}")
        return creator(config, service_context)

    def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None) -> LLMRerank:
        """Return an LLMRerank built from the config fields and the service context."""
        options = config.model_dump()
        return LLMRerank(**options, service_context=service_context)
# Module-level entry point: the bound method of a factory instance created at import time.
get_rankers = RankerFactory().get_rankers

View file

@ -1,6 +1,5 @@
"""Retrievers init"""
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.factory import get_retriever
__all__ = ["SimpleHybridRetriever", "get_retriever"]
__all__ = ["SimpleHybridRetriever"]

View file

@ -17,5 +17,16 @@ class RAGRetriever(BaseRetriever):
def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Retrieve nodes"""
class ModifiableRAGRetriever(RAGRetriever):
    """Retriever interface that supports adding nodes after construction.

    The subclass hook accepts any class that defines ``add_nodes`` somewhere in
    its MRO, so such classes satisfy ``isinstance``/``issubclass`` checks
    against this interface without inheriting from it (this relies on the
    base class using ``ABCMeta``).
    """

    @classmethod
    def __subclasshook__(cls, C):
        # Answer only for this interface itself. Subclasses inherit this hook,
        # and without the guard every class with add_nodes would also be
        # reported as a subclass of each concrete modifiable retriever.
        if cls is ModifiableRAGRetriever and any("add_nodes" in B.__dict__ for B in C.__mro__):
            return True
        return NotImplemented

    @abstractmethod
    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Add nodes to the retriever; implement this to support adding docs."""

View file

@ -1,62 +0,0 @@
"""Retriever Factory"""
import faiss
from llama_index import StorageContext, VectorStoreIndex
from llama_index.indices.base import BaseIndex
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
FAISSRetrieverConfig,
RetrieverConfigType,
)
class RetrieverFactory:
    """Factory that turns retriever configs into retrievers built on an index."""

    def __init__(self):
        """Register one creator per supported retriever config type."""
        registry = {
            FAISSRetrieverConfig: self._create_faiss_retriever,
            BM25RetrieverConfig: self._create_bm25_retriever,
        }
        self.retriever_creators = registry

    def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
        """Create the retriever described by the configs.

        Without configs the index's own retriever is returned. A single
        config yields its retriever directly; several are wrapped in a
        SimpleHybridRetriever that shares the index's service context.
        """
        if not configs:
            return self._default_retriever(index)

        built = [self._get_retriever(index, cfg) for cfg in configs]
        if len(built) > 1:
            return SimpleHybridRetriever(*built, service_context=index.service_context)
        return built[0]

    def _default_retriever(self, index: BaseIndex) -> RAGRetriever:
        """Return the index's own retriever."""
        return index.as_retriever()

    def _get_retriever(self, index: BaseIndex, config: RetrieverConfigType) -> RAGRetriever:
        """Build the retriever for one config; raise ValueError when its type is not registered."""
        creator = self.retriever_creators.get(type(config))
        if not creator:
            raise ValueError(f"Unknown retriever config: {config}")
        return creator(index, config)

    def _create_faiss_retriever(self, index: BaseIndex, config: FAISSRetrieverConfig):
        """Build a FAISS-backed vector index over the index's stored nodes and wrap it in a FAISSRetriever."""
        # Flat L2 FAISS index sized by the configured dimensions.
        faiss_index = faiss.IndexFlatL2(config.dimensions)
        store = FaissVectorStore(faiss_index=faiss_index)
        context = StorageContext.from_defaults(vector_store=store)
        stored_nodes = list(index.docstore.docs.values())
        vector_index = VectorStoreIndex(
            nodes=stored_nodes,
            storage_context=context,
            service_context=index.service_context,
        )
        return FAISSRetriever(vector_index, **config.model_dump())

    def _create_bm25_retriever(self, index: BaseIndex, config: BM25RetrieverConfig):
        """Build a DynamicBM25Retriever from the config fields and the given index."""
        options = config.model_dump()
        return DynamicBM25Retriever.from_defaults(**options, index=index)
# Module-level entry point: the bound method of a factory instance created at import time.
get_retriever = RetrieverFactory().get_retriever

View file

@ -2,7 +2,7 @@ import pytest
from llama_index import VectorStoreIndex
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
class TestSimpleEngine:
@ -99,7 +99,7 @@ class TestSimpleEngine:
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
mock_retriever = mocker.MagicMock(spec=RAGRetriever)
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index.service_context.node_parser.get_nodes_from_documents = lambda x: ["node1", "node2"]

View file

@ -1,47 +0,0 @@
import pytest
from llama_index import ServiceContext
from llama_index.postprocessor import LLMRerank
from metagpt.rag.rankers.factory import RankerFactory
from metagpt.rag.schema import LLMRankerConfig
class TestRankerFactory:
    """Tests for RankerFactory.get_rankers."""

    @pytest.fixture
    def mock_service_context(self, mocker):
        """Provide a MagicMock restricted to the ServiceContext spec."""
        service_context = mocker.MagicMock(spec=ServiceContext)
        return service_context

    def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context):
        """Without configs, exactly one LLMRerank is returned."""
        ranker_factory = RankerFactory()

        result = ranker_factory.get_rankers(service_context=mock_service_context)

        assert len(result) == 1
        assert isinstance(result[0], LLMRerank)

    def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context):
        """An LLMRankerConfig yields one LLMRerank carrying the configured top_n."""
        llm_config = LLMRankerConfig(top_n=3)
        ranker_factory = RankerFactory()

        result = ranker_factory.get_rankers(configs=[llm_config], service_context=mock_service_context)

        assert len(result) == 1
        assert isinstance(result[0], LLMRerank)
        assert result[0].top_n == 3

    def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context):
        """A config of an unregistered type makes get_rankers raise ValueError."""
        # Use MagicMock to simulate an unknown config type.
        unknown_config = mocker.MagicMock()
        ranker_factory = RankerFactory()

        with pytest.raises(ValueError):
            ranker_factory.get_rankers(configs=[unknown_config], service_context=mock_service_context)

View file

@ -1,12 +1,18 @@
import pytest
from llama_index import ServiceContext
from llama_index.indices.base import BaseIndex
from llama_index.postprocessor import LLMRerank
from metagpt.rag.factory import RankerFactory, RetrieverFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.factory import RetrieverFactory
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
from metagpt.rag.schema import (
BM25RetrieverConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
)
class TestRetrieverFactory:
@ -28,20 +34,20 @@ class TestRetrieverFactory:
@pytest.fixture
def mock_faiss_vector_store(self, mocker):
    """Patch FaissVectorStore in metagpt.rag.factory, the module that now builds retrievers."""
    return mocker.patch("metagpt.rag.factory.FaissVectorStore")
@pytest.fixture
def mock_storage_context(self, mocker):
    """Patch StorageContext in metagpt.rag.factory, the module that now builds retrievers."""
    return mocker.patch("metagpt.rag.factory.StorageContext")
@pytest.fixture
def mock_vector_store_index(self, mocker):
    """Patch VectorStoreIndex in metagpt.rag.factory, the module that now builds retrievers."""
    return mocker.patch("metagpt.rag.factory.VectorStoreIndex")
@pytest.fixture
def mock_dynamic_bm25_retriever(self, mocker):
    """Replace DynamicBM25Retriever in metagpt.rag.factory with a spec'd MagicMock."""
    mock = mocker.MagicMock(spec=DynamicBM25Retriever)
    return mocker.patch("metagpt.rag.factory.DynamicBM25Retriever", mock)
def test_get_retriever_with_no_configs_returns_default_retriever(self, mock_base_index):
factory = RetrieverFactory()
@ -81,3 +87,44 @@ class TestRetrieverFactory:
factory = RetrieverFactory()
with pytest.raises(ValueError):
factory.get_retriever(index=mock_base_index, configs=[mock_unknown_config])
class TestRankerFactory:
    """Tests for RankerFactory.get_rankers."""

    @pytest.fixture
    def mock_service_context(self, mocker):
        """Provide a MagicMock restricted to the ServiceContext spec."""
        service_context = mocker.MagicMock(spec=ServiceContext)
        return service_context

    def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context):
        """Without configs, exactly one LLMRerank is returned."""
        ranker_factory = RankerFactory()

        result = ranker_factory.get_rankers(service_context=mock_service_context)

        assert len(result) == 1
        assert isinstance(result[0], LLMRerank)

    def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context):
        """An LLMRankerConfig yields one LLMRerank carrying the configured top_n."""
        llm_config = LLMRankerConfig(top_n=3)
        ranker_factory = RankerFactory()

        result = ranker_factory.get_rankers(configs=[llm_config], service_context=mock_service_context)

        assert len(result) == 1
        assert isinstance(result[0], LLMRerank)
        assert result[0].top_n == 3

    def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context):
        """A config of an unregistered type makes get_rankers raise ValueError."""
        # Use MagicMock to simulate an unknown config type.
        unknown_config = mocker.MagicMock()
        ranker_factory = RankerFactory()

        with pytest.raises(ValueError):
            ranker_factory.get_rankers(configs=[unknown_config], service_context=mock_service_context)