mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add rag pipeline unittest
This commit is contained in:
parent
254088b026
commit
a4c095300c
12 changed files with 355 additions and 85 deletions
|
|
@ -1,10 +1,17 @@
|
|||
"""Simple Engine."""
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.callbacks.base import CallbackManager
|
||||
from llama_index.core.base_retriever import BaseRetriever
|
||||
from llama_index.embeddings.base import BaseEmbedding
|
||||
from llama_index.indices.base import BaseIndex
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.query_engine import RetrieverQueryEngine
|
||||
from llama_index.response_synthesizers import BaseSynthesizer
|
||||
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
|
||||
|
||||
from metagpt.rag.llm import get_default_llm
|
||||
|
|
@ -16,6 +23,29 @@ from metagpt.utils.embedding import get_embedding
|
|||
|
||||
|
||||
class SimpleEngine(RetrieverQueryEngine):
|
||||
"""
|
||||
SimpleEngine is a lightweight and easy-to-use search engine that integrates
|
||||
document reading, embedding, indexing, retrieving, and ranking functionalities
|
||||
into a single, straightforward workflow. It is designed to quickly set up a
|
||||
search engine from a collection of documents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever: BaseRetriever,
|
||||
response_synthesizer: Optional[BaseSynthesizer] = None,
|
||||
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
index: Optional[BaseIndex] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
retriever=retriever,
|
||||
response_synthesizer=response_synthesizer,
|
||||
node_postprocessors=node_postprocessors,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
self.index = index
|
||||
|
||||
@classmethod
|
||||
def from_docs(
|
||||
cls,
|
||||
|
|
@ -31,9 +61,14 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""This engine is designed to be simple and straightforward
|
||||
|
||||
Args:
|
||||
input_dir (str): Path to the directory.
|
||||
input_files (list): List of file paths to read
|
||||
(Optional; overrides input_dir, exclude)
|
||||
input_dir: Path to the directory.
|
||||
input_files: List of file paths to read (Optional; overrides input_dir, exclude).
|
||||
llm: Must supported by llama index.
|
||||
embed_model: Must supported by llama index.
|
||||
chunk_size: The size of text chunks (in tokens) to split documents into for embedding.
|
||||
chunk_overlap: The number of tokens for overlapping between consecutive chunks. Helps in maintaining context continuity.
|
||||
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
|
||||
ranker_configs: Configuration for rankers.
|
||||
"""
|
||||
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
|
||||
service_context = ServiceContext.from_defaults(
|
||||
|
|
@ -46,7 +81,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
retriever = get_retriever(index, configs=retriever_configs)
|
||||
rankers = get_rankers(configs=ranker_configs, service_context=service_context)
|
||||
|
||||
return SimpleEngine(retriever=retriever, node_postprocessors=rankers)
|
||||
return cls(retriever=retriever, node_postprocessors=rankers, index=index)
|
||||
|
||||
async def asearch(self, content: str, **kwargs) -> str:
|
||||
"""Inplement tools.SearchInterface"""
|
||||
|
|
@ -58,6 +93,8 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
return await super().aretrieve(query_bundle)
|
||||
|
||||
def add_docs(self, input_files: list[str]):
|
||||
"""Add docs to retriever"""
|
||||
documents = SimpleDirectoryReader(input_files=input_files).load_data()
|
||||
nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
|
||||
retriever: RAGRetriever = self.retriever
|
||||
retriever.add_docs(documents)
|
||||
retriever.add_nodes(nodes)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
"""Rankers Factory"""
|
||||
from llama_index import ServiceContext
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
|
|
@ -19,18 +20,18 @@ class RankerFactory:
|
|||
|
||||
return [self._get_ranker(config, service_context) for config in configs]
|
||||
|
||||
def _default_ranker(self, service_context: ServiceContext = None):
|
||||
def _default_ranker(self, service_context: ServiceContext = None) -> LLMRerank:
|
||||
return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context)
|
||||
|
||||
def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None):
|
||||
def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None) -> BaseNodePostprocessor:
|
||||
create_func = self.ranker_creators.get(type(config))
|
||||
if create_func:
|
||||
return create_func(config, service_context)
|
||||
|
||||
raise ValueError(f"Unknown ranker config: {config}")
|
||||
|
||||
def _create_llm_ranker(self, config, service_context=None):
|
||||
return LLMRerank(top_n=config.top_n, service_context=service_context)
|
||||
def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None) -> LLMRerank:
|
||||
return LLMRerank(**config.model_dump(), service_context=service_context)
|
||||
|
||||
|
||||
get_rankers = RankerFactory().get_rankers
|
||||
|
|
|
|||
|
|
@ -3,21 +3,19 @@
|
|||
|
||||
from abc import abstractmethod
|
||||
|
||||
from llama_index import Document
|
||||
from llama_index.retrievers import BaseRetriever
|
||||
from llama_index.schema import NodeWithScore, QueryType
|
||||
from llama_index.schema import BaseNode, NodeWithScore, QueryType
|
||||
|
||||
|
||||
class RAGRetriever(BaseRetriever):
|
||||
"""inherit from llama_index"""
|
||||
"""Inherit from llama_index"""
|
||||
|
||||
@abstractmethod
|
||||
async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""retrieve nodes"""
|
||||
|
||||
@abstractmethod
|
||||
def add_docs(self, documents: list[Document]) -> None:
|
||||
"""add docs"""
|
||||
"""Retrieve nodes"""
|
||||
|
||||
def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""retrieve nodes"""
|
||||
"""Retrieve nodes"""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""To support add docs, must inplement this func"""
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
from llama_index import Document
|
||||
from llama_index.retrievers import BM25Retriever
|
||||
from llama_index.schema import BaseNode
|
||||
|
||||
|
||||
class DynamicBM25Retriever(BM25Retriever):
|
||||
def add_docs(self, documents: list[Document]):
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs):
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
except ImportError:
|
||||
raise ImportError("Please install rank_bm25: pip install rank-bm25")
|
||||
|
||||
self._nodes.extend(documents)
|
||||
self._nodes.extend(nodes)
|
||||
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
|
||||
self.bm25 = BM25Okapi(self._corpus)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
"""Retriever Factory"""
|
||||
import faiss
|
||||
from llama_index import StorageContext, VectorStoreIndex
|
||||
from llama_index.indices.base import BaseIndex
|
||||
|
|
@ -22,6 +23,7 @@ class RetrieverFactory:
|
|||
}
|
||||
|
||||
def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
|
||||
"""Creates and returns a retriever instance based on the provided configurations."""
|
||||
if not configs:
|
||||
return self._default_retriever(index)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from llama_index import Document
|
||||
from llama_index.retrievers import VectorIndexRetriever
|
||||
from llama_index.schema import BaseNode
|
||||
|
||||
|
||||
class FAISSRetriever(VectorIndexRetriever):
|
||||
def add_docs(self, documents: list[Document]):
|
||||
for document in documents:
|
||||
self._index.insert(document)
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs):
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Hybrid retriever."""
|
||||
from llama_index import Document, ServiceContext
|
||||
from llama_index.schema import QueryType
|
||||
from llama_index import ServiceContext
|
||||
from llama_index.schema import BaseNode, QueryType
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
|
||||
|
|
@ -37,6 +37,6 @@ class SimpleHybridRetriever(RAGRetriever):
|
|||
node_ids.add(n.node.node_id)
|
||||
return result
|
||||
|
||||
def add_docs(self, documents: list[Document]):
|
||||
def add_nodes(self, nodes: list[BaseNode]):
|
||||
for r in self.retrievers:
|
||||
r.add_docs(documents)
|
||||
r.add_nodes(nodes)
|
||||
|
|
|
|||
|
|
@ -1,67 +1,115 @@
|
|||
# from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from llama_index import VectorStoreIndex
|
||||
|
||||
# import pytest
|
||||
|
||||
# from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
|
||||
|
||||
# class TestSimpleEngineFromDocs:
|
||||
# def test_from_docs(self, mocker):
|
||||
# # Mock dependencies
|
||||
# mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
# mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
|
||||
class TestSimpleEngine:
|
||||
def test_from_docs(self, mocker):
|
||||
# Mock
|
||||
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
|
||||
|
||||
# mock_service_context = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults")
|
||||
# mock_vector_store_index = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
# mock_vector_index_retriever = mocker.patch("metagpt.rag.engines.simple.VectorIndexRetriever")
|
||||
mock_service_context = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults")
|
||||
mock_service_context.return_value = "service_context"
|
||||
|
||||
# # Setup
|
||||
# input_dir = "test_dir"
|
||||
# input_files = ["test_file1", "test_file2"]
|
||||
# embed_model = mocker.MagicMock()
|
||||
# llm = mocker.MagicMock()
|
||||
# chunk_size = 100
|
||||
# chunk_overlap = 10
|
||||
# similarity_top_k = 5
|
||||
mock_vector_store_index = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
mock_get_retriever = mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
mock_get_rankers = mocker.patch("metagpt.rag.engines.simple.get_rankers")
|
||||
|
||||
# # Execute
|
||||
# engine = SimpleEngine.from_docs(
|
||||
# input_dir=input_dir,
|
||||
# input_files=input_files,
|
||||
# embed_model=embed_model,
|
||||
# llm=llm,
|
||||
# chunk_size=chunk_size,
|
||||
# chunk_overlap=chunk_overlap,
|
||||
# similarity_top_k=similarity_top_k,
|
||||
# )
|
||||
# Setup
|
||||
input_dir = "test_dir"
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
embed_model = mocker.MagicMock()
|
||||
llm = mocker.MagicMock()
|
||||
chunk_size = 100
|
||||
chunk_overlap = 10
|
||||
retriever_configs = mocker.MagicMock()
|
||||
ranker_configs = mocker.MagicMock()
|
||||
|
||||
# # Assertions
|
||||
# mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
# mock_service_context.assert_called_once_with(
|
||||
# embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm
|
||||
# )
|
||||
# mock_vector_store_index.assert_called_once_with(
|
||||
# ["document1", "document2"], service_context=mock_service_context.return_value
|
||||
# )
|
||||
# mock_vector_index_retriever.assert_called_once_with(
|
||||
# index=mock_vector_store_index.return_value, similarity_top_k=similarity_top_k
|
||||
# )
|
||||
# assert isinstance(engine, SimpleEngine)
|
||||
# Execute
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_dir=input_dir,
|
||||
input_files=input_files,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_asearch_calls_aquery(self, mocker):
|
||||
# # Mock
|
||||
# test_query = "test query"
|
||||
# expected_result = "expected result"
|
||||
# mock_aquery = AsyncMock(return_value=expected_result)
|
||||
# Assertions
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_service_context.assert_called_once_with(
|
||||
embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm
|
||||
)
|
||||
mock_vector_store_index.assert_called_once_with(
|
||||
["document1", "document2"], service_context=mock_service_context.return_value
|
||||
)
|
||||
mock_get_retriever.assert_called_once_with(mock_vector_store_index.return_value, configs=retriever_configs)
|
||||
mock_get_rankers.assert_called_once_with(
|
||||
configs=ranker_configs, service_context=mock_service_context.return_value
|
||||
)
|
||||
|
||||
# # Setup
|
||||
# engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
# engine.aquery = mock_aquery
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
# # Execute
|
||||
# result = await engine.asearch(test_query)
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
# Mock
|
||||
test_query = "test query"
|
||||
expected_result = "expected result"
|
||||
mock_aquery = mocker.AsyncMock(return_value=expected_result)
|
||||
|
||||
# # Assertions
|
||||
# mock_aquery.assert_called_once_with(test_query)
|
||||
# assert result == expected_result
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
engine.aquery = mock_aquery
|
||||
|
||||
# Execute
|
||||
result = await engine.asearch(test_query)
|
||||
|
||||
# Assertions
|
||||
mock_aquery.assert_called_once_with(test_query)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aretrieve(self, mocker):
|
||||
# Mock
|
||||
mock_query_bundle = mocker.patch("metagpt.rag.engines.simple.QueryBundle", return_value="query_bundle")
|
||||
mock_super_aretrieve = mocker.patch(
|
||||
"metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock
|
||||
)
|
||||
mock_super_aretrieve.return_value = ["node_with_score"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
test_query = "test query"
|
||||
|
||||
# Execute
|
||||
result = await engine.aretrieve(test_query)
|
||||
|
||||
# Assertions
|
||||
mock_query_bundle.assert_called_once_with(test_query)
|
||||
mock_super_aretrieve.assert_called_once_with("query_bundle")
|
||||
assert result == ["node_with_score"]
|
||||
|
||||
def test_add_docs(self, mocker):
|
||||
# Mock
|
||||
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
|
||||
|
||||
mock_retriever = mocker.MagicMock(spec=RAGRetriever)
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index.service_context.node_parser.get_nodes_from_documents = lambda x: ["node1", "node2"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Execute
|
||||
engine.add_docs(input_files=input_files)
|
||||
|
||||
# Assertions
|
||||
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
|
||||
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
|
||||
|
|
|
|||
47
tests/metagpt/rag/rankers/test_ranker_factory.py
Normal file
47
tests/metagpt/rag/rankers/test_ranker_factory.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import pytest
|
||||
from llama_index import ServiceContext
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
|
||||
from metagpt.rag.rankers.factory import RankerFactory
|
||||
from metagpt.rag.schema import LLMRankerConfig
|
||||
|
||||
|
||||
class TestRankerFactory:
|
||||
@pytest.fixture
|
||||
def mock_service_context(self, mocker):
|
||||
return mocker.MagicMock(spec=ServiceContext)
|
||||
|
||||
def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context):
|
||||
# Setup
|
||||
factory = RankerFactory()
|
||||
|
||||
# Execute
|
||||
rankers = factory.get_rankers(service_context=mock_service_context)
|
||||
|
||||
# Assertions
|
||||
assert len(rankers) == 1
|
||||
assert isinstance(rankers[0], LLMRerank)
|
||||
|
||||
def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context):
|
||||
# Setup
|
||||
config = LLMRankerConfig(top_n=3)
|
||||
factory = RankerFactory()
|
||||
|
||||
# Execute
|
||||
rankers = factory.get_rankers(configs=[config], service_context=mock_service_context)
|
||||
|
||||
# Assertions
|
||||
assert len(rankers) == 1
|
||||
assert isinstance(rankers[0], LLMRerank)
|
||||
assert rankers[0].top_n == 3
|
||||
|
||||
def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context):
|
||||
# Mock
|
||||
mock_config = mocker.MagicMock() # 使用 MagicMock 来模拟一个未知的配置类型
|
||||
|
||||
# Setup
|
||||
factory = RankerFactory()
|
||||
|
||||
# Execute & Assertions
|
||||
with pytest.raises(ValueError):
|
||||
factory.get_rankers(configs=[mock_config], service_context=mock_service_context)
|
||||
33
tests/metagpt/rag/retrievers/test_bm25_retriever.py
Normal file
33
tests/metagpt/rag/retrievers/test_bm25_retriever.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import pytest
|
||||
from llama_index.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
|
||||
|
||||
class TestDynamicBM25Retriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# 创建模拟的Document对象
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc1.get_content.return_value = "Document content 1"
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.doc2.get_content.return_value = "Document content 2"
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# 模拟nodes和tokenizer参数
|
||||
mock_nodes = []
|
||||
mock_tokenizer = mocker.MagicMock()
|
||||
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi")
|
||||
|
||||
# 初始化DynamicBM25Retriever对象,并提供必需的参数
|
||||
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer)
|
||||
|
||||
def test_add_docs_updates_nodes_and_corpus(self):
|
||||
# Execute
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
# Assertions
|
||||
assert len(self.retriever._nodes) == len(self.mock_nodes)
|
||||
assert len(self.retriever._corpus) == len(self.mock_nodes)
|
||||
self.retriever._tokenizer.assert_called()
|
||||
self.mock_bm25okapi.assert_called()
|
||||
22
tests/metagpt/rag/retrievers/test_faiss_retriever.py
Normal file
22
tests/metagpt/rag/retrievers/test_faiss_retriever.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import pytest
|
||||
from llama_index.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
|
||||
|
||||
class TestFAISSRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# 创建模拟的Document对象
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# 模拟FAISSRetriever的_index属性
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = FAISSRetriever(self.mock_index)
|
||||
|
||||
def test_add_docs_calls_insert_for_each_document(self, mocker):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
assert self.mock_index.insert_nodes.assert_called
|
||||
83
tests/metagpt/rag/retrievers/test_retriever_factory.py
Normal file
83
tests/metagpt/rag/retrievers/test_retriever_factory.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import pytest
|
||||
from llama_index.indices.base import BaseIndex
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.factory import RetrieverFactory
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
|
||||
|
||||
|
||||
class TestRetrieverFactory:
|
||||
@pytest.fixture
|
||||
def mock_base_index(self, mocker):
|
||||
mock = mocker.MagicMock(spec=BaseIndex)
|
||||
mock.as_retriever.return_value = mocker.MagicMock(spec=RAGRetriever)
|
||||
mock.service_context = mocker.MagicMock()
|
||||
mock.docstore.docs.values.return_value = []
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faiss_retriever_config(self):
|
||||
return FAISSRetrieverConfig(dimensions=128)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bm25_retriever_config(self):
|
||||
return BM25RetrieverConfig()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faiss_vector_store(self, mocker):
|
||||
return mocker.patch("metagpt.rag.retrievers.factory.FaissVectorStore")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage_context(self, mocker):
|
||||
return mocker.patch("metagpt.rag.retrievers.factory.StorageContext")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.retrievers.factory.VectorStoreIndex")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dynamic_bm25_retriever(self, mocker):
|
||||
mock = mocker.MagicMock(spec=DynamicBM25Retriever)
|
||||
return mocker.patch("metagpt.rag.retrievers.factory.DynamicBM25Retriever", mock)
|
||||
|
||||
def test_get_retriever_with_no_configs_returns_default_retriever(self, mock_base_index):
|
||||
factory = RetrieverFactory()
|
||||
retriever = factory.get_retriever(index=mock_base_index)
|
||||
assert isinstance(retriever, RAGRetriever)
|
||||
|
||||
def test_get_retriever_with_specific_config_returns_correct_retriever(
|
||||
self,
|
||||
mock_base_index,
|
||||
mock_faiss_retriever_config,
|
||||
mock_faiss_vector_store,
|
||||
mock_storage_context,
|
||||
mock_vector_store_index,
|
||||
):
|
||||
factory = RetrieverFactory()
|
||||
retriever = factory.get_retriever(index=mock_base_index, configs=[mock_faiss_retriever_config])
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid_retriever(
|
||||
self,
|
||||
mock_base_index,
|
||||
mock_faiss_retriever_config,
|
||||
mock_bm25_retriever_config,
|
||||
mock_faiss_vector_store,
|
||||
mock_storage_context,
|
||||
mock_vector_store_index,
|
||||
mock_dynamic_bm25_retriever,
|
||||
):
|
||||
factory = RetrieverFactory()
|
||||
retriever = factory.get_retriever(
|
||||
index=mock_base_index, configs=[mock_faiss_retriever_config, mock_bm25_retriever_config]
|
||||
)
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_get_retriever_with_unknown_config_raises_value_error(self, mock_base_index, mocker):
|
||||
mock_unknown_config = mocker.MagicMock()
|
||||
factory = RetrieverFactory()
|
||||
with pytest.raises(ValueError):
|
||||
factory.get_retriever(index=mock_base_index, configs=[mock_unknown_config])
|
||||
Loading…
Add table
Add a link
Reference in a new issue