mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-02 20:32:38 +02:00
simplify rag factory
This commit is contained in:
parent
a4c095300c
commit
dd965a2149
10 changed files with 183 additions and 167 deletions
|
|
@ -14,10 +14,9 @@ from llama_index.query_engine import RetrieverQueryEngine
|
|||
from llama_index.response_synthesizers import BaseSynthesizer
|
||||
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
|
||||
|
||||
from metagpt.rag.factory import get_rankers, get_retriever
|
||||
from metagpt.rag.llm import get_default_llm
|
||||
from metagpt.rag.rankers import get_rankers
|
||||
from metagpt.rag.retrievers import get_retriever
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.schema import RankerConfigType, RetrieverConfigType
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
|
||||
|
|
@ -93,8 +92,10 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
return await super().aretrieve(query_bundle)
|
||||
|
||||
def add_docs(self, input_files: list[str]):
    """Load the given files and add their nodes to the retriever.

    The retriever must support modification, i.e. it must provide ``add_nodes``.

    Args:
        input_files: Paths of the files to read with ``SimpleDirectoryReader``.

    Raises:
        TypeError: If ``self.retriever`` is not a ``ModifiableRAGRetriever``.
    """
    # Fail early, before any file is read, when the retriever cannot accept new nodes.
    if not isinstance(self.retriever, ModifiableRAGRetriever):
        raise TypeError(f"retriever must implement add_nodes to support add_docs: {type(self.retriever)}")

    documents = SimpleDirectoryReader(input_files=input_files).load_data()
    nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
    # Add the parsed nodes exactly once.
    self.retriever.add_nodes(nodes)
|
||||
|
|
|
|||
109
metagpt/rag/factory.py
Normal file
109
metagpt/rag/factory.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""Factory for creating retriever, ranker"""
|
||||
from typing import Any, Callable
|
||||
|
||||
import faiss
|
||||
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
|
||||
from llama_index.indices.base import BaseIndex
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
LLMRankerConfig,
|
||||
RankerConfigType,
|
||||
RetrieverConfigType,
|
||||
)
|
||||
|
||||
|
||||
class BaseFactory:
    """Base factory that creates instances from configuration objects.

    A registry maps configuration types to creator functions. Each config is
    dispatched to the creator registered for its type, or, failing that, for
    the nearest registered base class of its type.
    """

    def __init__(self, creators: dict[Any, Callable]):
        """Store the creator registry.

        Args:
            creators: Maps configuration types to creator functions. Each
                creator is called as ``creator(config, **kwargs)``, so its
                first positional argument must be the config.
        """
        self.creators = creators

    def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]:
        """Create one instance per config, or a single default instance.

        Args:
            configs: Configuration objects to instantiate. When ``None`` or
                empty, the subclass default instance is returned instead.
            **kwargs: Forwarded unchanged to the creator / default hook.

        Returns:
            A list of created instances, in the same order as ``configs``
            (or a one-element list holding the default instance).

        Raises:
            ValueError: If a config has no registered creator.
            NotImplementedError: If no configs are given and the subclass
                does not override ``_default_instance``.
        """
        if not configs:
            return [self._default_instance(**kwargs)]

        return [self._get_instance(config, **kwargs) for config in configs]

    def _get_instance(self, config: Any, **kwargs) -> Any:
        """Create the instance for a single config via its registered creator.

        Raises:
            ValueError: If neither the config's type nor any of its base
                classes is registered.
        """
        # Walk the MRO (exact type first) so that a subclass of a registered
        # config type reuses its parent's creator instead of being rejected.
        for config_type in type(config).__mro__:
            create_func = self.creators.get(config_type)
            if create_func:
                return create_func(config, **kwargs)

        raise ValueError(f"Unknown config: {config}")

    def _default_instance(self, **kwargs) -> Any:
        """Return the instance used when no configs are supplied (subclass hook)."""
        raise NotImplementedError("This method should be implemented by subclasses.")
|
||||
|
||||
|
||||
class RetrieverFactory(BaseFactory):
    """Factory that turns retriever configs into retriever instances."""

    def __init__(self):
        """Register one creator per supported retriever config type."""
        super().__init__(
            {
                FAISSRetrieverConfig: self._create_faiss_retriever,
                BM25RetrieverConfig: self._create_bm25_retriever,
            }
        )

    def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
        """Create a retriever for ``index`` from the given configs.

        A single resulting retriever is returned as-is; several are wrapped
        in a ``SimpleHybridRetriever`` sharing the index's service context.
        """
        built = super().get_instances(configs, index=index)

        if len(built) <= 1:
            return built[0]
        return SimpleHybridRetriever(*built, service_context=index.service_context)

    def _default_instance(self, index: BaseIndex) -> RAGRetriever:
        """Fall back to the index's own retriever when no config is given."""
        return index.as_retriever()

    def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex, **kwargs) -> FAISSRetriever:
        """Build a FAISS-backed vector index over the index's docstore nodes and wrap it."""
        faiss_index = faiss.IndexFlatL2(config.dimensions)
        faiss_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=faiss_store)
        # Re-index every node already held by the source index's docstore.
        existing_nodes = list(index.docstore.docs.values())
        vector_index = VectorStoreIndex(
            nodes=existing_nodes,
            storage_context=storage_context,
            service_context=index.service_context,
        )
        return FAISSRetriever(**config.model_dump(), index=vector_index)

    def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex, **kwargs) -> DynamicBM25Retriever:
        """Build a BM25 retriever from the config fields and the given index."""
        options = config.model_dump()
        return DynamicBM25Retriever.from_defaults(**options, index=index)
|
||||
|
||||
|
||||
class RankerFactory(BaseFactory):
    """Factory that turns ranker configs into node post-processors."""

    def __init__(self):
        """Register one creator per supported ranker config type."""
        super().__init__({LLMRankerConfig: self._create_llm_ranker})

    def get_rankers(
        self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
    ) -> list[BaseNodePostprocessor]:
        """Create one ranker per config, or a single default ranker when none are given."""
        rankers = super().get_instances(configs, service_context=service_context)
        return rankers

    def _default_instance(self, service_context: ServiceContext = None) -> LLMRerank:
        """Build an ``LLMRerank`` using the default ``top_n`` of ``LLMRankerConfig``."""
        default_top_n = LLMRankerConfig().top_n
        return LLMRerank(top_n=default_top_n, service_context=service_context)

    def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None, **kwargs) -> LLMRerank:
        """Build an ``LLMRerank`` from the config fields."""
        options = config.model_dump()
        return LLMRerank(**options, service_context=service_context)
|
||||
|
||||
|
||||
# Module-level entry points: bound methods of factory instances created once at import time.
get_retriever = RetrieverFactory().get_retriever
|
||||
get_rankers = RankerFactory().get_rankers
|
||||
|
|
@ -1,6 +1 @@
|
|||
"""Rankers init"""
|
||||
|
||||
from metagpt.rag.rankers.factory import get_rankers
|
||||
|
||||
|
||||
__all__ = ["get_rankers"]
|
||||
|
|
|
|||
|
|
@ -1,37 +0,0 @@
|
|||
"""Rankers Factory"""
|
||||
from llama_index import ServiceContext
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
|
||||
from metagpt.rag.schema import LLMRankerConfig, RankerConfigType
|
||||
|
||||
|
||||
class RankerFactory:
    """Builds ranker post-processors from ranker configs."""

    def __init__(self):
        """Register one creator per supported ranker config type."""
        self.ranker_creators = {LLMRankerConfig: self._create_llm_ranker}

    def get_rankers(
        self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
    ) -> list[BaseNodePostprocessor]:
        """Create one ranker per config, or a single default ranker when none are given."""
        if configs:
            return [self._get_ranker(cfg, service_context) for cfg in configs]
        return [self._default_ranker(service_context)]

    def _default_ranker(self, service_context: ServiceContext = None) -> LLMRerank:
        """Build an ``LLMRerank`` using the default ``top_n`` of ``LLMRankerConfig``."""
        default_top_n = LLMRankerConfig().top_n
        return LLMRerank(top_n=default_top_n, service_context=service_context)

    def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None) -> BaseNodePostprocessor:
        """Dispatch a config to the creator registered for its exact type."""
        creator = self.ranker_creators.get(type(config))
        if not creator:
            raise ValueError(f"Unknown ranker config: {config}")
        return creator(config, service_context)

    def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None) -> LLMRerank:
        """Build an ``LLMRerank`` from the config fields."""
        options = config.model_dump()
        return LLMRerank(**options, service_context=service_context)
|
||||
|
||||
|
||||
# Module-level entry point: bound method of a RankerFactory created once at import time.
get_rankers = RankerFactory().get_rankers
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
"""Retrievers init"""
|
||||
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.factory import get_retriever
|
||||
|
||||
__all__ = ["SimpleHybridRetriever", "get_retriever"]
|
||||
__all__ = ["SimpleHybridRetriever"]
|
||||
|
|
|
|||
|
|
@ -17,5 +17,16 @@ class RAGRetriever(BaseRetriever):
|
|||
def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""Retrieve nodes"""
|
||||
|
||||
|
||||
class ModifiableRAGRetriever(RAGRetriever):
    """Retriever that supports modification through ``add_nodes``."""

    @classmethod
    def __subclasshook__(cls, C):
        """Treat any class that defines ``add_nodes`` as a virtual subclass.

        NOTE(review): this hook only takes effect if the class hierarchy uses
        ``ABCMeta`` — confirm against ``RAGRetriever`` / ``BaseRetriever``.
        """
        # Restrict the structural check to this ABC itself. The hook is inherited,
        # so without the guard every subclass of ModifiableRAGRetriever would also
        # accept any class that merely defines add_nodes.
        if cls is ModifiableRAGRetriever and any("add_nodes" in B.__dict__ for B in C.__mro__):
            return True
        return NotImplemented

    @abstractmethod
    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Add nodes to the retriever; must be implemented to support adding docs."""
|
||||
|
|
|
|||
|
|
@ -1,62 +0,0 @@
|
|||
"""Retriever Factory"""
|
||||
import faiss
|
||||
from llama_index import StorageContext, VectorStoreIndex
|
||||
from llama_index.indices.base import BaseIndex
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
RetrieverConfigType,
|
||||
)
|
||||
|
||||
|
||||
class RetrieverFactory:
    """Builds retrievers from retriever configs."""

    def __init__(self):
        """Register one creator per supported retriever config type."""
        creators = {}
        creators[FAISSRetrieverConfig] = self._create_faiss_retriever
        creators[BM25RetrieverConfig] = self._create_bm25_retriever
        self.retriever_creators = creators

    def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
        """Create a retriever for ``index`` from the given configs.

        With no configs the index's own retriever is returned. A single config
        yields its retriever directly; several are wrapped in a
        ``SimpleHybridRetriever``.
        """
        if not configs:
            return self._default_retriever(index)

        built = [self._get_retriever(index, cfg) for cfg in configs]
        if len(built) <= 1:
            return built[0]
        return SimpleHybridRetriever(*built, service_context=index.service_context)

    def _default_retriever(self, index: BaseIndex) -> RAGRetriever:
        """Fall back to the index's own retriever."""
        return index.as_retriever()

    def _get_retriever(self, index: BaseIndex, config: RetrieverConfigType) -> RAGRetriever:
        """Dispatch a config to the creator registered for its exact type."""
        creator = self.retriever_creators.get(type(config))
        if not creator:
            raise ValueError(f"Unknown retriever config: {config}")
        return creator(index, config)

    def _create_faiss_retriever(self, index: BaseIndex, config: FAISSRetrieverConfig):
        """Build a FAISS-backed vector index over the index's docstore nodes and wrap it."""
        faiss_index = faiss.IndexFlatL2(config.dimensions)
        faiss_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=faiss_store)
        # Re-index every node already held by the source index's docstore.
        existing_nodes = list(index.docstore.docs.values())
        vector_index = VectorStoreIndex(
            nodes=existing_nodes,
            storage_context=storage_context,
            service_context=index.service_context,
        )
        return FAISSRetriever(vector_index, **config.model_dump())

    def _create_bm25_retriever(self, index: BaseIndex, config: BM25RetrieverConfig):
        """Build a BM25 retriever from the config fields and the given index."""
        options = config.model_dump()
        return DynamicBM25Retriever.from_defaults(**options, index=index)
|
||||
|
||||
|
||||
# Module-level entry point: bound method of a RetrieverFactory created once at import time.
get_retriever = RetrieverFactory().get_retriever
|
||||
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
from llama_index import VectorStoreIndex
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
|
||||
|
||||
class TestSimpleEngine:
|
||||
|
|
@ -99,7 +99,7 @@ class TestSimpleEngine:
|
|||
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
|
||||
|
||||
mock_retriever = mocker.MagicMock(spec=RAGRetriever)
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index.service_context.node_parser.get_nodes_from_documents = lambda x: ["node1", "node2"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
import pytest
|
||||
from llama_index import ServiceContext
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
|
||||
from metagpt.rag.rankers.factory import RankerFactory
|
||||
from metagpt.rag.schema import LLMRankerConfig
|
||||
|
||||
|
||||
class TestRankerFactory:
    """Tests for ``RankerFactory.get_rankers``."""

    @pytest.fixture
    def mock_service_context(self, mocker):
        """Provide a MagicMock constrained to the ``ServiceContext`` interface."""
        return mocker.MagicMock(spec=ServiceContext)

    def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context):
        """Without configs, exactly one default ``LLMRerank`` is returned."""
        ranker_factory = RankerFactory()

        result = ranker_factory.get_rankers(service_context=mock_service_context)

        assert len(result) == 1
        assert isinstance(result[0], LLMRerank)

    def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context):
        """An ``LLMRankerConfig`` yields an ``LLMRerank`` carrying the configured ``top_n``."""
        llm_config = LLMRankerConfig(top_n=3)
        ranker_factory = RankerFactory()

        result = ranker_factory.get_rankers(configs=[llm_config], service_context=mock_service_context)

        assert len(result) == 1
        only_ranker = result[0]
        assert isinstance(only_ranker, LLMRerank)
        assert only_ranker.top_n == 3

    def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context):
        """A config of an unregistered type is rejected with ``ValueError``."""
        # Use a MagicMock to simulate an unknown config type.
        unknown_config = mocker.MagicMock()
        ranker_factory = RankerFactory()

        with pytest.raises(ValueError):
            ranker_factory.get_rankers(configs=[unknown_config], service_context=mock_service_context)
|
||||
|
|
@ -1,12 +1,18 @@
|
|||
import pytest
|
||||
from llama_index import ServiceContext
|
||||
from llama_index.indices.base import BaseIndex
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
|
||||
from metagpt.rag.factory import RankerFactory, RetrieverFactory
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.factory import RetrieverFactory
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
LLMRankerConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestRetrieverFactory:
|
||||
|
|
@ -28,20 +34,20 @@ class TestRetrieverFactory:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_faiss_vector_store(self, mocker):
|
||||
return mocker.patch("metagpt.rag.retrievers.factory.FaissVectorStore")
|
||||
return mocker.patch("metagpt.rag.factory.FaissVectorStore")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage_context(self, mocker):
|
||||
return mocker.patch("metagpt.rag.retrievers.factory.StorageContext")
|
||||
return mocker.patch("metagpt.rag.factory.StorageContext")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.retrievers.factory.VectorStoreIndex")
|
||||
return mocker.patch("metagpt.rag.factory.VectorStoreIndex")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dynamic_bm25_retriever(self, mocker):
|
||||
mock = mocker.MagicMock(spec=DynamicBM25Retriever)
|
||||
return mocker.patch("metagpt.rag.retrievers.factory.DynamicBM25Retriever", mock)
|
||||
return mocker.patch("metagpt.rag.factory.DynamicBM25Retriever", mock)
|
||||
|
||||
def test_get_retriever_with_no_configs_returns_default_retriever(self, mock_base_index):
|
||||
factory = RetrieverFactory()
|
||||
|
|
@ -81,3 +87,44 @@ class TestRetrieverFactory:
|
|||
factory = RetrieverFactory()
|
||||
with pytest.raises(ValueError):
|
||||
factory.get_retriever(index=mock_base_index, configs=[mock_unknown_config])
|
||||
|
||||
|
||||
class TestRankerFactory:
    """Tests for ``RankerFactory.get_rankers``."""

    @pytest.fixture
    def mock_service_context(self, mocker):
        """Provide a MagicMock constrained to the ``ServiceContext`` interface."""
        return mocker.MagicMock(spec=ServiceContext)

    def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context):
        """Without configs, exactly one default ``LLMRerank`` is returned."""
        ranker_factory = RankerFactory()

        result = ranker_factory.get_rankers(service_context=mock_service_context)

        assert len(result) == 1
        assert isinstance(result[0], LLMRerank)

    def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context):
        """An ``LLMRankerConfig`` yields an ``LLMRerank`` carrying the configured ``top_n``."""
        llm_config = LLMRankerConfig(top_n=3)
        ranker_factory = RankerFactory()

        result = ranker_factory.get_rankers(configs=[llm_config], service_context=mock_service_context)

        assert len(result) == 1
        only_ranker = result[0]
        assert isinstance(only_ranker, LLMRerank)
        assert only_ranker.top_n == 3

    def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context):
        """A config of an unregistered type is rejected with ``ValueError``."""
        # Use a MagicMock to simulate an unknown config type.
        unknown_config = mocker.MagicMock()
        ranker_factory = RankerFactory()

        with pytest.raises(ValueError):
            ranker_factory.get_rankers(configs=[unknown_config], service_context=mock_service_context)
|
||||
Loading…
Add table
Add a link
Reference in a new issue