simplify rag factory

This commit is contained in:
seehi 2024-02-07 15:20:21 +08:00
parent a4c095300c
commit dd965a2149
10 changed files with 183 additions and 167 deletions

View file

@ -14,10 +14,9 @@ from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import BaseSynthesizer
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
from metagpt.rag.factory import get_rankers, get_retriever
from metagpt.rag.llm import get_default_llm
from metagpt.rag.rankers import get_rankers
from metagpt.rag.retrievers import get_retriever
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.schema import RankerConfigType, RetrieverConfigType
from metagpt.utils.embedding import get_embedding
@ -93,8 +92,10 @@ class SimpleEngine(RetrieverQueryEngine):
return await super().aretrieve(query_bundle)
def add_docs(self, input_files: list[str]):
"""Add docs to retriever"""
"""Add docs to retriever. retriever must has add_nodes func"""
if not isinstance(self.retriever, ModifiableRAGRetriever):
raise TypeError(f"must be inplement to add_docs: {type(self.retriever)}")
documents = SimpleDirectoryReader(input_files=input_files).load_data()
nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
retriever: RAGRetriever = self.retriever
retriever.add_nodes(nodes)
self.retriever.add_nodes(nodes)

109
metagpt/rag/factory.py Normal file
View file

@ -0,0 +1,109 @@
"""Factory for creating retriever, ranker"""
from typing import Any, Callable
import faiss
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
from llama_index.indices.base import BaseIndex
from llama_index.postprocessor import LLMRerank
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
RankerConfigType,
RetrieverConfigType,
)
class BaseFactory:
"""
A base factory class for creating instances based on provided configurations.
It uses a registry of creator functions mapped to configuration types to instantiate objects dynamically.
"""
def __init__(self, creators: dict[Any, Callable]):
"""
Creators is a dictionary mapping configuration types to creator functions.
The first arg of Creator function should be config.
"""
self.creators = creators
def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]:
if not configs:
return [self._default_instance(**kwargs)]
return [self._get_instance(config, **kwargs) for config in configs]
def _get_instance(self, config: Any, **kwargs) -> Any:
create_func = self.creators.get(type(config))
if create_func:
return create_func(config, **kwargs)
raise ValueError(f"Unknown config: {config}")
def _default_instance(self, **kwargs) -> Any:
raise NotImplementedError("This method should be implemented by subclasses.")
class RetrieverFactory(BaseFactory):
def __init__(self):
creators = {
FAISSRetrieverConfig: self._create_faiss_retriever,
BM25RetrieverConfig: self._create_bm25_retriever,
}
super().__init__(creators)
def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
"""Creates and returns a retriever instance based on the provided configurations."""
retrievers = super().get_instances(configs, index=index)
return (
SimpleHybridRetriever(*retrievers, service_context=index.service_context)
if len(retrievers) > 1
else retrievers[0]
)
def _default_instance(self, index: BaseIndex) -> RAGRetriever:
return index.as_retriever()
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
storage_context = StorageContext.from_defaults(vector_store=vector_store)
vector_index = VectorStoreIndex(
nodes=list(index.docstore.docs.values()),
storage_context=storage_context,
service_context=index.service_context,
)
return FAISSRetriever(**config.model_dump(), index=vector_index)
def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex, **kwargs) -> DynamicBM25Retriever:
return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index)
class RankerFactory(BaseFactory):
def __init__(self):
creators = {
LLMRankerConfig: self._create_llm_ranker,
}
super().__init__(creators)
def get_rankers(
self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
) -> list[BaseNodePostprocessor]:
return super().get_instances(configs, service_context=service_context)
def _default_instance(self, service_context: ServiceContext = None) -> LLMRerank:
return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context)
def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None, **kwargs) -> LLMRerank:
return LLMRerank(**config.model_dump(), service_context=service_context)
get_retriever = RetrieverFactory().get_retriever
get_rankers = RankerFactory().get_rankers

View file

@ -1,6 +1 @@
"""Rankers init"""
from metagpt.rag.rankers.factory import get_rankers
__all__ = ["get_rankers"]

View file

@ -1,37 +0,0 @@
"""Rankers Factory"""
from llama_index import ServiceContext
from llama_index.postprocessor import LLMRerank
from llama_index.postprocessor.types import BaseNodePostprocessor
from metagpt.rag.schema import LLMRankerConfig, RankerConfigType
class RankerFactory:
def __init__(self):
self.ranker_creators = {
LLMRankerConfig: self._create_llm_ranker,
}
def get_rankers(
self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
) -> list[BaseNodePostprocessor]:
if not configs:
return [self._default_ranker(service_context)]
return [self._get_ranker(config, service_context) for config in configs]
def _default_ranker(self, service_context: ServiceContext = None) -> LLMRerank:
return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context)
def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None) -> BaseNodePostprocessor:
create_func = self.ranker_creators.get(type(config))
if create_func:
return create_func(config, service_context)
raise ValueError(f"Unknown ranker config: {config}")
def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None) -> LLMRerank:
return LLMRerank(**config.model_dump(), service_context=service_context)
get_rankers = RankerFactory().get_rankers

View file

@ -1,6 +1,5 @@
"""Retrievers init"""
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.factory import get_retriever
__all__ = ["SimpleHybridRetriever", "get_retriever"]
__all__ = ["SimpleHybridRetriever"]

View file

@ -17,5 +17,16 @@ class RAGRetriever(BaseRetriever):
def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Retrieve nodes"""
class ModifiableRAGRetriever(RAGRetriever):
"""Support modification."""
@classmethod
def __subclasshook__(cls, C):
if any("add_nodes" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented
@abstractmethod
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""To support add docs, must inplement this func"""

View file

@ -1,62 +0,0 @@
"""Retriever Factory"""
import faiss
from llama_index import StorageContext, VectorStoreIndex
from llama_index.indices.base import BaseIndex
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
FAISSRetrieverConfig,
RetrieverConfigType,
)
class RetrieverFactory:
def __init__(self):
self.retriever_creators = {
FAISSRetrieverConfig: self._create_faiss_retriever,
BM25RetrieverConfig: self._create_bm25_retriever,
}
def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
"""Creates and returns a retriever instance based on the provided configurations."""
if not configs:
return self._default_retriever(index)
retrievers = [self._get_retriever(index, config) for config in configs]
return (
SimpleHybridRetriever(*retrievers, service_context=index.service_context)
if len(retrievers) > 1
else retrievers[0]
)
def _default_retriever(self, index: BaseIndex) -> RAGRetriever:
return index.as_retriever()
def _get_retriever(self, index: BaseIndex, config: RetrieverConfigType) -> RAGRetriever:
create_func = self.retriever_creators.get(type(config))
if create_func:
return create_func(index, config)
raise ValueError(f"Unknown retriever config: {config}")
def _create_faiss_retriever(self, index: BaseIndex, config: FAISSRetrieverConfig):
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
storage_context = StorageContext.from_defaults(vector_store=vector_store)
vector_index = VectorStoreIndex(
nodes=list(index.docstore.docs.values()),
storage_context=storage_context,
service_context=index.service_context,
)
return FAISSRetriever(vector_index, **config.model_dump())
def _create_bm25_retriever(self, index: BaseIndex, config: BM25RetrieverConfig):
return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index)
get_retriever = RetrieverFactory().get_retriever