add rag pipeline unittest

This commit is contained in:
seehi 2024-02-06 20:15:03 +08:00 committed by betterwang
parent bc4848ab1e
commit ee31295b7d
12 changed files with 355 additions and 85 deletions

View file

@ -1,67 +1,115 @@
# from unittest.mock import AsyncMock
import pytest
from llama_index import VectorStoreIndex
# import pytest
# from metagpt.rag.engines import SimpleEngine
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.retrievers.base import RAGRetriever
# class TestSimpleEngineFromDocs:
# def test_from_docs(self, mocker):
# # Mock dependencies
# mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
# mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
class TestSimpleEngine:
    """Unit tests for ``SimpleEngine`` with all collaborators replaced by mocks."""

    def test_from_docs(self, mocker):
        """``from_docs`` reads the documents, builds the index and asks for retriever and rankers."""
        # Patch every collaborator inside the engine module.
        reader_cls = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
        reader_cls.return_value.load_data.return_value = ["document1", "document2"]

        service_context_factory = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults")
        service_context_factory.return_value = "service_context"

        index_factory = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
        retriever_getter = mocker.patch("metagpt.rag.engines.simple.get_retriever")
        rankers_getter = mocker.patch("metagpt.rag.engines.simple.get_rankers")

        # Arguments handed to the factory method, kept together so the
        # assertions below can refer back to the exact same objects.
        kwargs = dict(
            input_dir="test_dir",
            input_files=["test_file1", "test_file2"],
            embed_model=mocker.MagicMock(),
            llm=mocker.MagicMock(),
            chunk_size=100,
            chunk_overlap=10,
            retriever_configs=mocker.MagicMock(),
            ranker_configs=mocker.MagicMock(),
        )

        # Execute
        engine = SimpleEngine.from_docs(**kwargs)

        # Each collaborator must have been called exactly once with the forwarded arguments.
        reader_cls.assert_called_once_with(
            input_dir=kwargs["input_dir"], input_files=kwargs["input_files"]
        )
        service_context_factory.assert_called_once_with(
            embed_model=kwargs["embed_model"],
            chunk_size=kwargs["chunk_size"],
            chunk_overlap=kwargs["chunk_overlap"],
            llm=kwargs["llm"],
        )
        index_factory.assert_called_once_with(
            ["document1", "document2"], service_context=service_context_factory.return_value
        )
        retriever_getter.assert_called_once_with(
            index_factory.return_value, configs=kwargs["retriever_configs"]
        )
        rankers_getter.assert_called_once_with(
            configs=kwargs["ranker_configs"], service_context=service_context_factory.return_value
        )

        assert isinstance(engine, SimpleEngine)

    @pytest.mark.asyncio
    async def test_asearch(self, mocker):
        """``asearch`` forwards the query to ``aquery`` and returns its result."""
        query = "test query"
        answer = "expected result"
        aquery_stub = mocker.AsyncMock(return_value=answer)

        # Replace aquery on a real engine instance so only asearch itself is exercised.
        engine = SimpleEngine(retriever=mocker.MagicMock())
        engine.aquery = aquery_stub

        outcome = await engine.asearch(query)

        aquery_stub.assert_called_once_with(query)
        assert outcome == answer

    @pytest.mark.asyncio
    async def test_aretrieve(self, mocker):
        """``aretrieve`` wraps the query in a ``QueryBundle`` and calls the parent ``aretrieve``."""
        bundle_cls = mocker.patch("metagpt.rag.engines.simple.QueryBundle", return_value="query_bundle")
        parent_aretrieve = mocker.patch(
            "metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock
        )
        parent_aretrieve.return_value = ["node_with_score"]

        query = "test query"
        engine = SimpleEngine(retriever=mocker.MagicMock())

        nodes = await engine.aretrieve(query)

        bundle_cls.assert_called_once_with(query)
        parent_aretrieve.assert_called_once_with("query_bundle")
        assert nodes == ["node_with_score"]

    def test_add_docs(self, mocker):
        """``add_docs`` reads the given files and passes the parsed nodes to the retriever."""
        reader_cls = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
        reader_cls.return_value.load_data.return_value = ["document1", "document2"]

        retriever = mocker.MagicMock(spec=RAGRetriever)
        index = mocker.MagicMock(spec=VectorStoreIndex)
        # Stub the node parser so the documents above turn into two known nodes.
        index.service_context.node_parser.get_nodes_from_documents = lambda x: ["node1", "node2"]

        files = ["test_file1", "test_file2"]
        engine = SimpleEngine(retriever=retriever, index=index)

        engine.add_docs(input_files=files)

        reader_cls.assert_called_once_with(input_files=files)
        retriever.add_nodes.assert_called_once_with(["node1", "node2"])

View file

@ -0,0 +1,47 @@
import pytest
from llama_index import ServiceContext
from llama_index.postprocessor import LLMRerank
from metagpt.rag.rankers.factory import RankerFactory
from metagpt.rag.schema import LLMRankerConfig
class TestRankerFactory:
    """Unit tests for ``RankerFactory.get_rankers``."""

    @pytest.fixture
    def mock_service_context(self, mocker):
        """Provide a mock constrained to the ``ServiceContext`` interface."""
        return mocker.MagicMock(spec=ServiceContext)

    def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context):
        """Without configs, exactly one ``LLMRerank`` is returned."""
        rankers = RankerFactory().get_rankers(service_context=mock_service_context)

        assert len(rankers) == 1
        assert isinstance(rankers[0], LLMRerank)

    def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context):
        """An ``LLMRankerConfig`` yields one ``LLMRerank`` carrying the configured ``top_n``."""
        ranker_config = LLMRankerConfig(top_n=3)

        rankers = RankerFactory().get_rankers(configs=[ranker_config], service_context=mock_service_context)

        assert len(rankers) == 1
        only_ranker = rankers[0]
        assert isinstance(only_ranker, LLMRerank)
        assert only_ranker.top_n == 3

    def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context):
        """A config of an unrecognised type makes ``get_rankers`` raise ``ValueError``."""
        # A bare MagicMock stands in for an unknown config type.
        unknown_config = mocker.MagicMock()
        # Build the factory outside the raises-block so only get_rankers is under test.
        factory = RankerFactory()

        with pytest.raises(ValueError):
            factory.get_rankers(configs=[unknown_config], service_context=mock_service_context)

View file

@ -0,0 +1,33 @@
import pytest
from llama_index.schema import Node
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
class TestDynamicBM25Retriever:
    """Unit tests for ``DynamicBM25Retriever.add_nodes``."""

    @pytest.fixture(autouse=True)
    def setup(self, mocker):
        """Create two mocked nodes and a retriever that starts with an empty node list."""
        # Mocked Node objects, each returning a fixed content string.
        self.mock_nodes = []
        for text in ("Document content 1", "Document content 2"):
            node = mocker.MagicMock(spec=Node)
            node.get_content.return_value = text
            self.mock_nodes.append(node)
        self.doc1, self.doc2 = self.mock_nodes

        # Patch BM25Okapi before the retriever is constructed so calls to it can be observed.
        self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi")

        # Build the retriever with its required arguments: no initial nodes and a mocked tokenizer.
        self.retriever = DynamicBM25Retriever(nodes=[], tokenizer=mocker.MagicMock())

    def test_add_docs_updates_nodes_and_corpus(self):
        """Adding nodes grows ``_nodes`` and ``_corpus`` and calls the tokenizer and ``BM25Okapi``."""
        self.retriever.add_nodes(self.mock_nodes)

        expected_count = len(self.mock_nodes)
        assert len(self.retriever._nodes) == expected_count
        assert len(self.retriever._corpus) == expected_count
        self.retriever._tokenizer.assert_called()
        self.mock_bm25okapi.assert_called()

View file

@ -0,0 +1,22 @@
import pytest
from llama_index.schema import Node
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
class TestFAISSRetriever:
    """Unit tests for ``FAISSRetriever.add_nodes``."""

    @pytest.fixture(autouse=True)
    def setup(self, mocker):
        """Create two mocked nodes and a retriever built on a mocked index."""
        # Mocked Node objects.
        self.doc1 = mocker.MagicMock(spec=Node)
        self.doc2 = mocker.MagicMock(spec=Node)
        self.mock_nodes = [self.doc1, self.doc2]

        # Mocked index handed to the retriever (presumably stored as its _index attribute).
        self.mock_index = mocker.MagicMock()

        self.retriever = FAISSRetriever(self.mock_index)

    def test_add_docs_calls_insert_for_each_document(self, mocker):
        """``add_nodes`` must call ``insert_nodes`` on the underlying index."""
        self.retriever.add_nodes(self.mock_nodes)

        # The mock's assert method has to be *called*: `assert mock.assert_called`
        # only tests the truthiness of the bound method and therefore always passes.
        self.mock_index.insert_nodes.assert_called()

View file

@ -0,0 +1,83 @@
import pytest
from llama_index.indices.base import BaseIndex
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.factory import RetrieverFactory
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
class TestRetrieverFactory:
    """Unit tests for ``RetrieverFactory.get_retriever``."""

    @pytest.fixture
    def mock_base_index(self, mocker):
        """Provide a ``BaseIndex``-spec'd mock with retriever, service context and empty docstore."""
        index = mocker.MagicMock(spec=BaseIndex)
        index.as_retriever.return_value = mocker.MagicMock(spec=RAGRetriever)
        index.service_context = mocker.MagicMock()
        index.docstore.docs.values.return_value = []
        return index

    @pytest.fixture
    def mock_faiss_retriever_config(self):
        """Provide a FAISS retriever config with 128 dimensions."""
        return FAISSRetrieverConfig(dimensions=128)

    @pytest.fixture
    def mock_bm25_retriever_config(self):
        """Provide a BM25 retriever config built with its defaults."""
        return BM25RetrieverConfig()

    @pytest.fixture
    def mock_faiss_vector_store(self, mocker):
        """Patch ``FaissVectorStore`` as referenced by the factory module."""
        return mocker.patch("metagpt.rag.retrievers.factory.FaissVectorStore")

    @pytest.fixture
    def mock_storage_context(self, mocker):
        """Patch ``StorageContext`` as referenced by the factory module."""
        return mocker.patch("metagpt.rag.retrievers.factory.StorageContext")

    @pytest.fixture
    def mock_vector_store_index(self, mocker):
        """Patch ``VectorStoreIndex`` as referenced by the factory module."""
        return mocker.patch("metagpt.rag.retrievers.factory.VectorStoreIndex")

    @pytest.fixture
    def mock_dynamic_bm25_retriever(self, mocker):
        """Replace ``DynamicBM25Retriever`` in the factory module with a spec'd mock."""
        return mocker.patch(
            "metagpt.rag.retrievers.factory.DynamicBM25Retriever",
            mocker.MagicMock(spec=DynamicBM25Retriever),
        )

    def test_get_retriever_with_no_configs_returns_default_retriever(self, mock_base_index):
        """Without configs the factory returns a ``RAGRetriever``."""
        result = RetrieverFactory().get_retriever(index=mock_base_index)

        assert isinstance(result, RAGRetriever)

    def test_get_retriever_with_specific_config_returns_correct_retriever(
        self,
        mock_base_index,
        mock_faiss_retriever_config,
        mock_faiss_vector_store,
        mock_storage_context,
        mock_vector_store_index,
    ):
        """A single FAISS config yields a ``FAISSRetriever``."""
        # The patching fixtures are requested only for their side effects.
        result = RetrieverFactory().get_retriever(
            index=mock_base_index, configs=[mock_faiss_retriever_config]
        )

        assert isinstance(result, FAISSRetriever)

    def test_get_retriever_with_multiple_configs_returns_hybrid_retriever(
        self,
        mock_base_index,
        mock_faiss_retriever_config,
        mock_bm25_retriever_config,
        mock_faiss_vector_store,
        mock_storage_context,
        mock_vector_store_index,
        mock_dynamic_bm25_retriever,
    ):
        """FAISS and BM25 configs together yield a ``SimpleHybridRetriever``."""
        both_configs = [mock_faiss_retriever_config, mock_bm25_retriever_config]

        result = RetrieverFactory().get_retriever(index=mock_base_index, configs=both_configs)

        assert isinstance(result, SimpleHybridRetriever)

    def test_get_retriever_with_unknown_config_raises_value_error(self, mock_base_index, mocker):
        """A config of an unrecognised type makes ``get_retriever`` raise ``ValueError``."""
        unknown_config = mocker.MagicMock()
        # Build the factory outside the raises-block so only get_retriever is under test.
        factory = RetrieverFactory()

        with pytest.raises(ValueError):
            factory.get_retriever(index=mock_base_index, configs=[unknown_config])