simplify rag factory

This commit is contained in:
seehi 2024-02-07 15:20:21 +08:00
parent a4c095300c
commit dd965a2149
10 changed files with 183 additions and 167 deletions

View file

@ -2,7 +2,7 @@ import pytest
from llama_index import VectorStoreIndex
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
class TestSimpleEngine:
@ -99,7 +99,7 @@ class TestSimpleEngine:
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
mock_retriever = mocker.MagicMock(spec=RAGRetriever)
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index.service_context.node_parser.get_nodes_from_documents = lambda x: ["node1", "node2"]

View file

@ -1,47 +0,0 @@
import pytest
from llama_index import ServiceContext
from llama_index.postprocessor import LLMRerank
from metagpt.rag.rankers.factory import RankerFactory
from metagpt.rag.schema import LLMRankerConfig
class TestRankerFactory:
    """Tests for ``RankerFactory.get_rankers``."""

    @pytest.fixture
    def mock_service_context(self, mocker):
        """Provide a MagicMock restricted to the ``ServiceContext`` spec."""
        service_context_double = mocker.MagicMock(spec=ServiceContext)
        return service_context_double

    def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context):
        """Without configs, exactly one ``LLMRerank`` is expected back."""
        ranker_factory = RankerFactory()

        produced = ranker_factory.get_rankers(service_context=mock_service_context)

        assert len(produced) == 1
        default_ranker = produced[0]
        assert isinstance(default_ranker, LLMRerank)

    def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context):
        """An ``LLMRankerConfig`` yields one ``LLMRerank`` carrying its ``top_n``."""
        ranker_config = LLMRankerConfig(top_n=3)
        ranker_factory = RankerFactory()

        produced = ranker_factory.get_rankers(
            configs=[ranker_config],
            service_context=mock_service_context,
        )

        assert len(produced) == 1
        configured_ranker = produced[0]
        assert isinstance(configured_ranker, LLMRerank)
        assert configured_ranker.top_n == 3

    def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context):
        """A config of an unrecognized type must raise ``ValueError``."""
        # A bare MagicMock stands in for an unknown config type.
        unknown_config = mocker.MagicMock()
        ranker_factory = RankerFactory()

        with pytest.raises(ValueError):
            ranker_factory.get_rankers(
                configs=[unknown_config],
                service_context=mock_service_context,
            )

View file

@ -1,12 +1,18 @@
import pytest
from llama_index import ServiceContext
from llama_index.indices.base import BaseIndex
from llama_index.postprocessor import LLMRerank
from metagpt.rag.factory import RankerFactory, RetrieverFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.factory import RetrieverFactory
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
from metagpt.rag.schema import (
BM25RetrieverConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
)
class TestRetrieverFactory:
@ -28,20 +34,20 @@ class TestRetrieverFactory:
@pytest.fixture
def mock_faiss_vector_store(self, mocker):
return mocker.patch("metagpt.rag.retrievers.factory.FaissVectorStore")
return mocker.patch("metagpt.rag.factory.FaissVectorStore")
@pytest.fixture
def mock_storage_context(self, mocker):
return mocker.patch("metagpt.rag.retrievers.factory.StorageContext")
return mocker.patch("metagpt.rag.factory.StorageContext")
@pytest.fixture
def mock_vector_store_index(self, mocker):
return mocker.patch("metagpt.rag.retrievers.factory.VectorStoreIndex")
return mocker.patch("metagpt.rag.factory.VectorStoreIndex")
@pytest.fixture
def mock_dynamic_bm25_retriever(self, mocker):
mock = mocker.MagicMock(spec=DynamicBM25Retriever)
return mocker.patch("metagpt.rag.retrievers.factory.DynamicBM25Retriever", mock)
return mocker.patch("metagpt.rag.factory.DynamicBM25Retriever", mock)
def test_get_retriever_with_no_configs_returns_default_retriever(self, mock_base_index):
factory = RetrieverFactory()
@ -81,3 +87,44 @@ class TestRetrieverFactory:
factory = RetrieverFactory()
with pytest.raises(ValueError):
factory.get_retriever(index=mock_base_index, configs=[mock_unknown_config])
class TestRankerFactory:
    """Tests for ``RankerFactory.get_rankers``."""

    @pytest.fixture
    def mock_service_context(self, mocker):
        """Provide a MagicMock restricted to the ``ServiceContext`` spec."""
        service_context_double = mocker.MagicMock(spec=ServiceContext)
        return service_context_double

    def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context):
        """Without configs, exactly one ``LLMRerank`` is expected back."""
        ranker_factory = RankerFactory()

        produced = ranker_factory.get_rankers(service_context=mock_service_context)

        assert len(produced) == 1
        default_ranker = produced[0]
        assert isinstance(default_ranker, LLMRerank)

    def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context):
        """An ``LLMRankerConfig`` yields one ``LLMRerank`` carrying its ``top_n``."""
        ranker_config = LLMRankerConfig(top_n=3)
        ranker_factory = RankerFactory()

        produced = ranker_factory.get_rankers(
            configs=[ranker_config],
            service_context=mock_service_context,
        )

        assert len(produced) == 1
        configured_ranker = produced[0]
        assert isinstance(configured_ranker, LLMRerank)
        assert configured_ranker.top_n == 3

    def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context):
        """A config of an unrecognized type must raise ``ValueError``."""
        # A bare MagicMock stands in for an unknown config type.
        unknown_config = mocker.MagicMock()
        ranker_factory = RankerFactory()

        with pytest.raises(ValueError):
            ranker_factory.get_rankers(
                configs=[unknown_config],
                service_context=mock_service_context,
            )