mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-06 22:32:38 +02:00
simplify rag factory
This commit is contained in:
parent
a4c095300c
commit
dd965a2149
10 changed files with 183 additions and 167 deletions
|
|
@ -14,10 +14,9 @@ from llama_index.query_engine import RetrieverQueryEngine
|
|||
from llama_index.response_synthesizers import BaseSynthesizer
|
||||
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
|
||||
|
||||
from metagpt.rag.factory import get_rankers, get_retriever
|
||||
from metagpt.rag.llm import get_default_llm
|
||||
from metagpt.rag.rankers import get_rankers
|
||||
from metagpt.rag.retrievers import get_retriever
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.schema import RankerConfigType, RetrieverConfigType
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
|
||||
|
|
@ -93,8 +92,10 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
return await super().aretrieve(query_bundle)
|
||||
|
||||
def add_docs(self, input_files: list[str]):
    """Load the given files, parse them into nodes and add them to the retriever.

    Args:
        input_files: Paths of the documents to load.

    Raises:
        TypeError: If the current retriever is not a ModifiableRAGRetriever,
            i.e. it cannot accept new nodes through ``add_nodes``.
    """
    # Fail early, before any file is read, when the retriever cannot be extended.
    if not isinstance(self.retriever, ModifiableRAGRetriever):
        raise TypeError(f"retriever must implement add_nodes to use add_docs: {type(self.retriever)}")

    documents = SimpleDirectoryReader(input_files=input_files).load_data()
    # Reuse the node parser of the index's service context to split the documents.
    nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
    # Add the nodes exactly once.
    self.retriever.add_nodes(nodes)
|
||||
|
|
|
|||
109
metagpt/rag/factory.py
Normal file
109
metagpt/rag/factory.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""Factory for creating retriever, ranker"""
|
||||
from typing import Any, Callable
|
||||
|
||||
import faiss
|
||||
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
|
||||
from llama_index.indices.base import BaseIndex
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
LLMRankerConfig,
|
||||
RankerConfigType,
|
||||
RetrieverConfigType,
|
||||
)
|
||||
|
||||
|
||||
class BaseFactory:
    """Create objects from configuration instances via a registry of creator functions.

    The registry maps a configuration type to the callable that builds the
    matching object. Subclasses supply the registry and implement
    ``_default_instance`` for the case where no configuration is given.
    """

    def __init__(self, creators: dict[Any, Callable]):
        """Store the creator registry.

        Args:
            creators: Maps configuration types to creator functions. Each creator
                is called as ``creator(config, **kwargs)``, so its first argument
                must be the configuration.
        """
        self.creators = creators

    def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]:
        """Build one instance per configuration.

        Args:
            configs: Configurations to build from. When ``None`` or empty, a
                single default instance is returned instead.
            **kwargs: Extra keyword arguments forwarded to every creator, or to
                ``_default_instance`` when there are no configurations.

        Returns:
            The created instances, in the order of ``configs``.

        Raises:
            ValueError: If a configuration has no registered creator.
            NotImplementedError: If no configurations are given and the subclass
                does not implement ``_default_instance``.
        """
        if not configs:
            return [self._default_instance(**kwargs)]

        return [self._get_instance(config, **kwargs) for config in configs]

    def _get_instance(self, config: Any, **kwargs) -> Any:
        """Build the instance for a single configuration.

        The creator registered for the exact type of ``config`` is preferred;
        otherwise the creator of the nearest registered base class is used, so
        subclasses of a registered configuration type are accepted too.

        Raises:
            ValueError: If neither the type of ``config`` nor any of its base
                classes has a registered creator.
        """
        # The MRO starts with type(config) itself, so exact matches behave as before.
        for config_type in type(config).__mro__:
            create_func = self.creators.get(config_type)
            if create_func:
                return create_func(config, **kwargs)

        raise ValueError(f"Unknown config: {config}")

    def _default_instance(self, **kwargs) -> Any:
        """Build the instance used when no configuration is supplied (subclass hook)."""
        raise NotImplementedError("This method should be implemented by subclasses.")
|
||||
|
||||
|
||||
class RetrieverFactory(BaseFactory):
    """Factory that builds retrievers from retriever configurations."""

    def __init__(self):
        # Register one creator per supported retriever configuration type.
        super().__init__(
            {
                FAISSRetrieverConfig: self._create_faiss_retriever,
                BM25RetrieverConfig: self._create_bm25_retriever,
            }
        )

    def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
        """Build a retriever for ``index`` from the given configurations.

        A single resulting retriever is returned as is; several retrievers are
        wrapped in a SimpleHybridRetriever that shares the index's service context.
        """
        built = super().get_instances(configs, index=index)
        if len(built) <= 1:
            return built[0]

        return SimpleHybridRetriever(*built, service_context=index.service_context)

    def _default_instance(self, index: BaseIndex) -> RAGRetriever:
        """Return the index's own retriever."""
        return index.as_retriever()

    def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex, **kwargs) -> FAISSRetriever:
        """Build a FAISS-backed vector index from the nodes in ``index`` and wrap it in a FAISSRetriever."""
        faiss_index = faiss.IndexFlatL2(config.dimensions)
        faiss_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=faiss_store)

        # Re-index every node already held in the source index's docstore.
        existing_nodes = list(index.docstore.docs.values())
        vector_index = VectorStoreIndex(
            nodes=existing_nodes,
            storage_context=storage_context,
            service_context=index.service_context,
        )

        retriever_kwargs = config.model_dump()
        return FAISSRetriever(**retriever_kwargs, index=vector_index)

    def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex, **kwargs) -> DynamicBM25Retriever:
        """Build a DynamicBM25Retriever over ``index`` from the configuration's fields."""
        retriever_kwargs = config.model_dump()
        return DynamicBM25Retriever.from_defaults(**retriever_kwargs, index=index)
|
||||
|
||||
|
||||
class RankerFactory(BaseFactory):
    """Factory that builds rankers (node postprocessors) from ranker configurations."""

    def __init__(self):
        # Register one creator per supported ranker configuration type.
        super().__init__({LLMRankerConfig: self._create_llm_ranker})

    def get_rankers(
        self,
        configs: list[RankerConfigType] = None,
        service_context: ServiceContext = None,
    ) -> list[BaseNodePostprocessor]:
        """Build one ranker per configuration, forwarding ``service_context`` to each creator."""
        rankers = super().get_instances(configs, service_context=service_context)
        return rankers

    def _default_instance(self, service_context: ServiceContext = None) -> LLMRerank:
        """Build an LLMRerank using the ``top_n`` of a default-constructed LLMRankerConfig."""
        default_top_n = LLMRankerConfig().top_n
        return LLMRerank(top_n=default_top_n, service_context=service_context)

    def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None, **kwargs) -> LLMRerank:
        """Build an LLMRerank from the configuration's fields."""
        ranker_kwargs = config.model_dump()
        return LLMRerank(**ranker_kwargs, service_context=service_context)
|
||||
|
||||
|
||||
# Module-level entry points, each bound to one factory instance created at import time.
_retriever_factory = RetrieverFactory()
_ranker_factory = RankerFactory()

get_retriever = _retriever_factory.get_retriever
get_rankers = _ranker_factory.get_rankers
|
||||
|
|
@ -1,6 +1 @@
|
|||
"""Rankers init"""
|
||||
|
||||
from metagpt.rag.rankers.factory import get_rankers
|
||||
|
||||
|
||||
__all__ = ["get_rankers"]
|
||||
|
|
|
|||
|
|
@ -1,37 +0,0 @@
|
|||
"""Rankers Factory"""
|
||||
from llama_index import ServiceContext
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
|
||||
from metagpt.rag.schema import LLMRankerConfig, RankerConfigType
|
||||
|
||||
|
||||
class RankerFactory:
    """Factory that builds rankers (node postprocessors) from ranker configurations."""

    def __init__(self):
        # Registry: configuration type -> bound creator method.
        self.ranker_creators = {LLMRankerConfig: self._create_llm_ranker}

    def get_rankers(
        self,
        configs: list[RankerConfigType] = None,
        service_context: ServiceContext = None,
    ) -> list[BaseNodePostprocessor]:
        """Build one ranker per configuration, or a single default ranker when none are given."""
        if configs:
            return [self._get_ranker(config, service_context) for config in configs]

        return [self._default_ranker(service_context)]

    def _default_ranker(self, service_context: ServiceContext = None) -> LLMRerank:
        """Build an LLMRerank using the ``top_n`` of a default-constructed LLMRankerConfig."""
        default_top_n = LLMRankerConfig().top_n
        return LLMRerank(top_n=default_top_n, service_context=service_context)

    def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None) -> BaseNodePostprocessor:
        """Dispatch to the creator registered for the exact type of ``config``.

        Raises:
            ValueError: If no creator is registered for that type.
        """
        create_func = self.ranker_creators.get(type(config))
        if not create_func:
            raise ValueError(f"Unknown ranker config: {config}")

        return create_func(config, service_context)

    def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None) -> LLMRerank:
        """Build an LLMRerank from the configuration's fields."""
        ranker_kwargs = config.model_dump()
        return LLMRerank(**ranker_kwargs, service_context=service_context)
|
||||
|
||||
|
||||
# Module-level entry point bound to one factory instance created at import time.
_ranker_factory = RankerFactory()
get_rankers = _ranker_factory.get_rankers
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
"""Retrievers init"""
|
||||
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.factory import get_retriever
|
||||
|
||||
__all__ = ["SimpleHybridRetriever", "get_retriever"]
|
||||
__all__ = ["SimpleHybridRetriever"]
|
||||
|
|
|
|||
|
|
@ -17,5 +17,16 @@ class RAGRetriever(BaseRetriever):
|
|||
def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""Retrieve nodes"""
|
||||
|
||||
|
||||
class ModifiableRAGRetriever(RAGRetriever):
    """Retriever that supports modification, i.e. accepts new nodes after creation.

    The subclass hook below recognises any class that defines ``add_nodes``
    somewhere in its MRO, so such classes need not inherit from this one
    explicitly (assuming the hierarchy uses ``ABCMeta`` - TODO confirm on
    ``RAGRetriever``).
    """

    @classmethod
    def __subclasshook__(cls, C):
        """Report ``C`` as a subclass when any class in its MRO defines ``add_nodes``."""
        # Only answer for this class itself. The hook is inherited, so without the
        # guard every concrete subclass would also claim any class that merely has
        # ``add_nodes`` as its own subclass.
        if cls is ModifiableRAGRetriever and any("add_nodes" in B.__dict__ for B in C.__mro__):
            return True
        # Defer to the regular subclass check in every other case.
        return NotImplemented

    @abstractmethod
    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Add ``nodes`` to the retriever; must be implemented to support adding docs."""
|
||||
|
|
|
|||
|
|
@ -1,62 +0,0 @@
|
|||
"""Retriever Factory"""
|
||||
import faiss
|
||||
from llama_index import StorageContext, VectorStoreIndex
|
||||
from llama_index.indices.base import BaseIndex
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
RetrieverConfigType,
|
||||
)
|
||||
|
||||
|
||||
class RetrieverFactory:
    """Factory that builds retrievers from retriever configurations."""

    def __init__(self):
        # Registry: configuration type -> bound creator method.
        self.retriever_creators = {
            FAISSRetrieverConfig: self._create_faiss_retriever,
            BM25RetrieverConfig: self._create_bm25_retriever,
        }

    def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
        """Build a retriever for ``index`` from the given configurations.

        Without configurations the index's own retriever is returned. A single
        configuration yields its retriever directly; several are wrapped in a
        SimpleHybridRetriever that shares the index's service context.
        """
        if not configs:
            return self._default_retriever(index)

        built = [self._get_retriever(index, config) for config in configs]
        if len(built) > 1:
            return SimpleHybridRetriever(*built, service_context=index.service_context)

        return built[0]

    def _default_retriever(self, index: BaseIndex) -> RAGRetriever:
        """Return the index's own retriever."""
        return index.as_retriever()

    def _get_retriever(self, index: BaseIndex, config: RetrieverConfigType) -> RAGRetriever:
        """Dispatch to the creator registered for the exact type of ``config``.

        Raises:
            ValueError: If no creator is registered for that type.
        """
        create_func = self.retriever_creators.get(type(config))
        if not create_func:
            raise ValueError(f"Unknown retriever config: {config}")

        return create_func(index, config)

    def _create_faiss_retriever(self, index: BaseIndex, config: FAISSRetrieverConfig):
        """Build a FAISS-backed vector index from the nodes in ``index`` and wrap it in a FAISSRetriever."""
        faiss_index = faiss.IndexFlatL2(config.dimensions)
        faiss_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=faiss_store)

        # Re-index every node already held in the source index's docstore.
        existing_nodes = list(index.docstore.docs.values())
        vector_index = VectorStoreIndex(
            nodes=existing_nodes,
            storage_context=storage_context,
            service_context=index.service_context,
        )

        retriever_kwargs = config.model_dump()
        return FAISSRetriever(vector_index, **retriever_kwargs)

    def _create_bm25_retriever(self, index: BaseIndex, config: BM25RetrieverConfig):
        """Build a DynamicBM25Retriever over ``index`` from the configuration's fields."""
        retriever_kwargs = config.model_dump()
        return DynamicBM25Retriever.from_defaults(**retriever_kwargs, index=index)
|
||||
|
||||
|
||||
# Module-level entry point bound to one factory instance created at import time.
_retriever_factory = RetrieverFactory()
get_retriever = _retriever_factory.get_retriever
|
||||
Loading…
Add table
Add a link
Reference in a new issue