diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index e136b4092..c4e3b6f31 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -1,10 +1,17 @@ """Simple Engine.""" +from typing import Optional + from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex +from llama_index.callbacks.base import CallbackManager +from llama_index.core.base_retriever import BaseRetriever from llama_index.embeddings.base import BaseEmbedding +from llama_index.indices.base import BaseIndex from llama_index.llms.llm import LLM +from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.query_engine import RetrieverQueryEngine +from llama_index.response_synthesizers import BaseSynthesizer from llama_index.schema import NodeWithScore, QueryBundle, QueryType from metagpt.rag.llm import get_default_llm @@ -16,6 +23,29 @@ from metagpt.utils.embedding import get_embedding class SimpleEngine(RetrieverQueryEngine): + """ + SimpleEngine is a lightweight and easy-to-use search engine that integrates + document reading, embedding, indexing, retrieving, and ranking functionalities + into a single, straightforward workflow. It is designed to quickly set up a + search engine from a collection of documents. + """ + + def __init__( + self, + retriever: BaseRetriever, + response_synthesizer: Optional[BaseSynthesizer] = None, + node_postprocessors: Optional[list[BaseNodePostprocessor]] = None, + callback_manager: Optional[CallbackManager] = None, + index: Optional[BaseIndex] = None, + ) -> None: + super().__init__( + retriever=retriever, + response_synthesizer=response_synthesizer, + node_postprocessors=node_postprocessors, + callback_manager=callback_manager, + ) + self.index = index + @classmethod def from_docs( cls, @@ -31,9 +61,14 @@ class SimpleEngine(RetrieverQueryEngine): """This engine is designed to be simple and straightforward Args: - input_dir (str): Path to the directory. 
- input_files (list): List of file paths to read - (Optional; overrides input_dir, exclude) + input_dir: Path to the directory. + input_files: List of file paths to read (Optional; overrides input_dir, exclude). + llm: Must be supported by llama index. + embed_model: Must be supported by llama index. + chunk_size: The size of text chunks (in tokens) to split documents into for embedding. + chunk_overlap: The number of tokens for overlapping between consecutive chunks. Helps in maintaining context continuity. + retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever. + ranker_configs: Configuration for rankers. """ documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() service_context = ServiceContext.from_defaults( @@ -46,7 +81,7 @@ class SimpleEngine(RetrieverQueryEngine): retriever = get_retriever(index, configs=retriever_configs) rankers = get_rankers(configs=ranker_configs, service_context=service_context) - return SimpleEngine(retriever=retriever, node_postprocessors=rankers) + return cls(retriever=retriever, node_postprocessors=rankers, index=index) async def asearch(self, content: str, **kwargs) -> str: """Inplement tools.SearchInterface""" @@ -58,6 +93,8 @@ class SimpleEngine(RetrieverQueryEngine): return await super().aretrieve(query_bundle) def add_docs(self, input_files: list[str]): + """Add docs to retriever""" documents = SimpleDirectoryReader(input_files=input_files).load_data() + nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents) retriever: RAGRetriever = self.retriever - retriever.add_docs(documents) + retriever.add_nodes(nodes) diff --git a/metagpt/rag/rankers/factory.py b/metagpt/rag/rankers/factory.py index 14dc89604..b139fdd92 100644 --- a/metagpt/rag/rankers/factory.py +++ b/metagpt/rag/rankers/factory.py @@ -1,3 +1,4 @@ +"""Rankers Factory""" from llama_index import ServiceContext from llama_index.postprocessor import LLMRerank from
llama_index.postprocessor.types import BaseNodePostprocessor @@ -19,18 +20,18 @@ class RankerFactory: return [self._get_ranker(config, service_context) for config in configs] - def _default_ranker(self, service_context: ServiceContext = None): + def _default_ranker(self, service_context: ServiceContext = None) -> LLMRerank: return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context) - def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None): + def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None) -> BaseNodePostprocessor: create_func = self.ranker_creators.get(type(config)) if create_func: return create_func(config, service_context) raise ValueError(f"Unknown ranker config: {config}") - def _create_llm_ranker(self, config, service_context=None): - return LLMRerank(top_n=config.top_n, service_context=service_context) + def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None) -> LLMRerank: + return LLMRerank(**config.model_dump(), service_context=service_context) get_rankers = RankerFactory().get_rankers diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py index 535e427c3..97590a138 100644 --- a/metagpt/rag/retrievers/base.py +++ b/metagpt/rag/retrievers/base.py @@ -3,21 +3,19 @@ from abc import abstractmethod -from llama_index import Document from llama_index.retrievers import BaseRetriever -from llama_index.schema import NodeWithScore, QueryType +from llama_index.schema import BaseNode, NodeWithScore, QueryType class RAGRetriever(BaseRetriever): - """inherit from llama_index""" + """Inherit from llama_index""" @abstractmethod async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]: - """retrieve nodes""" - - @abstractmethod - def add_docs(self, documents: list[Document]) -> None: - """add docs""" + """Retrieve nodes""" def _retrieve(self, query: QueryType) -> list[NodeWithScore]: - """retrieve nodes""" + """Retrieve nodes""" + 
+ def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: + """To support add docs, must implement this func""" diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index 4141827dd..c7257e00f 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -1,14 +1,14 @@ -from llama_index import Document from llama_index.retrievers import BM25Retriever +from llama_index.schema import BaseNode class DynamicBM25Retriever(BM25Retriever): - def add_docs(self, documents: list[Document]): + def add_nodes(self, nodes: list[BaseNode], **kwargs): try: from rank_bm25 import BM25Okapi except ImportError: raise ImportError("Please install rank_bm25: pip install rank-bm25") - self._nodes.extend(documents) + self._nodes.extend(nodes) self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] self.bm25 = BM25Okapi(self._corpus) diff --git a/metagpt/rag/retrievers/factory.py b/metagpt/rag/retrievers/factory.py index cde70e219..c2dcb2725 100644 --- a/metagpt/rag/retrievers/factory.py +++ b/metagpt/rag/retrievers/factory.py @@ -1,3 +1,4 @@ +"""Retriever Factory""" import faiss from llama_index import StorageContext, VectorStoreIndex from llama_index.indices.base import BaseIndex @@ -22,6 +23,7 @@ class RetrieverFactory: } def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever: + """Creates and returns a retriever instance based on the provided configurations.""" if not configs: return self._default_retriever(index) diff --git a/metagpt/rag/retrievers/faiss_retriever.py b/metagpt/rag/retrievers/faiss_retriever.py index 9888959e1..aa91aaaff 100644 --- a/metagpt/rag/retrievers/faiss_retriever.py +++ b/metagpt/rag/retrievers/faiss_retriever.py @@ -1,8 +1,7 @@ -from llama_index import Document from llama_index.retrievers import VectorIndexRetriever +from llama_index.schema import BaseNode class FAISSRetriever(VectorIndexRetriever): - def
add_docs(self, documents: list[Document]): - for document in documents: - self._index.insert(document) + def add_nodes(self, nodes: list[BaseNode], **kwargs): + self._index.insert_nodes(nodes, **kwargs) diff --git a/metagpt/rag/retrievers/hybrid_retriever.py b/metagpt/rag/retrievers/hybrid_retriever.py index f4e9c3479..04889b702 100644 --- a/metagpt/rag/retrievers/hybrid_retriever.py +++ b/metagpt/rag/retrievers/hybrid_retriever.py @@ -1,6 +1,6 @@ """Hybrid retriever.""" -from llama_index import Document, ServiceContext -from llama_index.schema import QueryType +from llama_index import ServiceContext +from llama_index.schema import BaseNode, QueryType from metagpt.rag.retrievers.base import RAGRetriever @@ -37,6 +37,6 @@ class SimpleHybridRetriever(RAGRetriever): node_ids.add(n.node.node_id) return result - def add_docs(self, documents: list[Document]): + def add_nodes(self, nodes: list[BaseNode]): for r in self.retrievers: - r.add_docs(documents) + r.add_nodes(nodes) diff --git a/tests/metagpt/rag/engine/test_simple.py b/tests/metagpt/rag/engine/test_simple.py index 2bea8f556..4d047b075 100644 --- a/tests/metagpt/rag/engine/test_simple.py +++ b/tests/metagpt/rag/engine/test_simple.py @@ -1,67 +1,115 @@ -# from unittest.mock import AsyncMock +import pytest +from llama_index import VectorStoreIndex -# import pytest - -# from metagpt.rag.engines import SimpleEngine +from metagpt.rag.engines import SimpleEngine +from metagpt.rag.retrievers.base import RAGRetriever -# class TestSimpleEngineFromDocs: -# def test_from_docs(self, mocker): -# # Mock dependencies -# mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") -# mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"] +class TestSimpleEngine: + def test_from_docs(self, mocker): + # Mock + mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") + 
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"] -# mock_service_context = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults") -# mock_vector_store_index = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents") -# mock_vector_index_retriever = mocker.patch("metagpt.rag.engines.simple.VectorIndexRetriever") + mock_service_context = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults") + mock_service_context.return_value = "service_context" -# # Setup -# input_dir = "test_dir" -# input_files = ["test_file1", "test_file2"] -# embed_model = mocker.MagicMock() -# llm = mocker.MagicMock() -# chunk_size = 100 -# chunk_overlap = 10 -# similarity_top_k = 5 + mock_vector_store_index = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents") + mock_get_retriever = mocker.patch("metagpt.rag.engines.simple.get_retriever") + mock_get_rankers = mocker.patch("metagpt.rag.engines.simple.get_rankers") -# # Execute -# engine = SimpleEngine.from_docs( -# input_dir=input_dir, -# input_files=input_files, -# embed_model=embed_model, -# llm=llm, -# chunk_size=chunk_size, -# chunk_overlap=chunk_overlap, -# similarity_top_k=similarity_top_k, -# ) + # Setup + input_dir = "test_dir" + input_files = ["test_file1", "test_file2"] + embed_model = mocker.MagicMock() + llm = mocker.MagicMock() + chunk_size = 100 + chunk_overlap = 10 + retriever_configs = mocker.MagicMock() + ranker_configs = mocker.MagicMock() -# # Assertions -# mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) -# mock_service_context.assert_called_once_with( -# embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm -# ) -# mock_vector_store_index.assert_called_once_with( -# ["document1", "document2"], service_context=mock_service_context.return_value -# ) -# mock_vector_index_retriever.assert_called_once_with( -# 
index=mock_vector_store_index.return_value, similarity_top_k=similarity_top_k -# ) -# assert isinstance(engine, SimpleEngine) + # Execute + engine = SimpleEngine.from_docs( + input_dir=input_dir, + input_files=input_files, + embed_model=embed_model, + llm=llm, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, + ) -# @pytest.mark.asyncio -# async def test_asearch_calls_aquery(self, mocker): -# # Mock -# test_query = "test query" -# expected_result = "expected result" -# mock_aquery = AsyncMock(return_value=expected_result) + # Assertions + mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) + mock_service_context.assert_called_once_with( + embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm + ) + mock_vector_store_index.assert_called_once_with( + ["document1", "document2"], service_context=mock_service_context.return_value + ) + mock_get_retriever.assert_called_once_with(mock_vector_store_index.return_value, configs=retriever_configs) + mock_get_rankers.assert_called_once_with( + configs=ranker_configs, service_context=mock_service_context.return_value + ) -# # Setup -# engine = SimpleEngine(retriever=mocker.MagicMock()) -# engine.aquery = mock_aquery + assert isinstance(engine, SimpleEngine) -# # Execute -# result = await engine.asearch(test_query) + @pytest.mark.asyncio + async def test_asearch(self, mocker): + # Mock + test_query = "test query" + expected_result = "expected result" + mock_aquery = mocker.AsyncMock(return_value=expected_result) -# # Assertions -# mock_aquery.assert_called_once_with(test_query) -# assert result == expected_result + # Setup + engine = SimpleEngine(retriever=mocker.MagicMock()) + engine.aquery = mock_aquery + + # Execute + result = await engine.asearch(test_query) + + # Assertions + mock_aquery.assert_called_once_with(test_query) + assert result == expected_result + + 
@pytest.mark.asyncio + async def test_aretrieve(self, mocker): + # Mock + mock_query_bundle = mocker.patch("metagpt.rag.engines.simple.QueryBundle", return_value="query_bundle") + mock_super_aretrieve = mocker.patch( + "metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock + ) + mock_super_aretrieve.return_value = ["node_with_score"] + + # Setup + engine = SimpleEngine(retriever=mocker.MagicMock()) + test_query = "test query" + + # Execute + result = await engine.aretrieve(test_query) + + # Assertions + mock_query_bundle.assert_called_once_with(test_query) + mock_super_aretrieve.assert_called_once_with("query_bundle") + assert result == ["node_with_score"] + + def test_add_docs(self, mocker): + # Mock + mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") + mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"] + + mock_retriever = mocker.MagicMock(spec=RAGRetriever) + mock_index = mocker.MagicMock(spec=VectorStoreIndex) + mock_index.service_context.node_parser.get_nodes_from_documents = lambda x: ["node1", "node2"] + + # Setup + engine = SimpleEngine(retriever=mock_retriever, index=mock_index) + input_files = ["test_file1", "test_file2"] + + # Execute + engine.add_docs(input_files=input_files) + + # Assertions + mock_simple_directory_reader.assert_called_once_with(input_files=input_files) + mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"]) diff --git a/tests/metagpt/rag/rankers/test_ranker_factory.py b/tests/metagpt/rag/rankers/test_ranker_factory.py new file mode 100644 index 000000000..ec335cee2 --- /dev/null +++ b/tests/metagpt/rag/rankers/test_ranker_factory.py @@ -0,0 +1,47 @@ +import pytest +from llama_index import ServiceContext +from llama_index.postprocessor import LLMRerank + +from metagpt.rag.rankers.factory import RankerFactory +from metagpt.rag.schema import LLMRankerConfig + + +class TestRankerFactory: + 
@pytest.fixture + def mock_service_context(self, mocker): + return mocker.MagicMock(spec=ServiceContext) + + def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context): + # Setup + factory = RankerFactory() + + # Execute + rankers = factory.get_rankers(service_context=mock_service_context) + + # Assertions + assert len(rankers) == 1 + assert isinstance(rankers[0], LLMRerank) + + def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context): + # Setup + config = LLMRankerConfig(top_n=3) + factory = RankerFactory() + + # Execute + rankers = factory.get_rankers(configs=[config], service_context=mock_service_context) + + # Assertions + assert len(rankers) == 1 + assert isinstance(rankers[0], LLMRerank) + assert rankers[0].top_n == 3 + + def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context): + # Mock + mock_config = mocker.MagicMock() # 使用 MagicMock 来模拟一个未知的配置类型 + + # Setup + factory = RankerFactory() + + # Execute & Assertions + with pytest.raises(ValueError): + factory.get_rankers(configs=[mock_config], service_context=mock_service_context) diff --git a/tests/metagpt/rag/retrievers/test_bm25_retriever.py b/tests/metagpt/rag/retrievers/test_bm25_retriever.py new file mode 100644 index 000000000..cc845a35a --- /dev/null +++ b/tests/metagpt/rag/retrievers/test_bm25_retriever.py @@ -0,0 +1,33 @@ +import pytest +from llama_index.schema import Node + +from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever + + +class TestDynamicBM25Retriever: + @pytest.fixture(autouse=True) + def setup(self, mocker): + # 创建模拟的Document对象 + self.doc1 = mocker.MagicMock(spec=Node) + self.doc1.get_content.return_value = "Document content 1" + self.doc2 = mocker.MagicMock(spec=Node) + self.doc2.get_content.return_value = "Document content 2" + self.mock_nodes = [self.doc1, self.doc2] + + # 模拟nodes和tokenizer参数 + mock_nodes = [] + mock_tokenizer = mocker.MagicMock() + 
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi") + + # 初始化DynamicBM25Retriever对象,并提供必需的参数 + self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer) + + def test_add_docs_updates_nodes_and_corpus(self): + # Execute + self.retriever.add_nodes(self.mock_nodes) + + # Assertions + assert len(self.retriever._nodes) == len(self.mock_nodes) + assert len(self.retriever._corpus) == len(self.mock_nodes) + self.retriever._tokenizer.assert_called() + self.mock_bm25okapi.assert_called() diff --git a/tests/metagpt/rag/retrievers/test_faiss_retriever.py b/tests/metagpt/rag/retrievers/test_faiss_retriever.py new file mode 100644 index 000000000..7d5a5a5a3 --- /dev/null +++ b/tests/metagpt/rag/retrievers/test_faiss_retriever.py @@ -0,0 +1,22 @@ +import pytest +from llama_index.schema import Node + +from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever + + +class TestFAISSRetriever: + @pytest.fixture(autouse=True) + def setup(self, mocker): + # 创建模拟的Document对象 + self.doc1 = mocker.MagicMock(spec=Node) + self.doc2 = mocker.MagicMock(spec=Node) + self.mock_nodes = [self.doc1, self.doc2] + + # 模拟FAISSRetriever的_index属性 + self.mock_index = mocker.MagicMock() + self.retriever = FAISSRetriever(self.mock_index) + + def test_add_docs_calls_insert_for_each_document(self, mocker): + self.retriever.add_nodes(self.mock_nodes) + + assert self.mock_index.insert_nodes.assert_called diff --git a/tests/metagpt/rag/retrievers/test_retriever_factory.py b/tests/metagpt/rag/retrievers/test_retriever_factory.py new file mode 100644 index 000000000..dc69a49fc --- /dev/null +++ b/tests/metagpt/rag/retrievers/test_retriever_factory.py @@ -0,0 +1,83 @@ +import pytest +from llama_index.indices.base import BaseIndex + +from metagpt.rag.retrievers.base import RAGRetriever +from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever +from metagpt.rag.retrievers.factory import RetrieverFactory +from metagpt.rag.retrievers.faiss_retriever import 
FAISSRetriever +from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever +from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig + + +class TestRetrieverFactory: + @pytest.fixture + def mock_base_index(self, mocker): + mock = mocker.MagicMock(spec=BaseIndex) + mock.as_retriever.return_value = mocker.MagicMock(spec=RAGRetriever) + mock.service_context = mocker.MagicMock() + mock.docstore.docs.values.return_value = [] + return mock + + @pytest.fixture + def mock_faiss_retriever_config(self): + return FAISSRetrieverConfig(dimensions=128) + + @pytest.fixture + def mock_bm25_retriever_config(self): + return BM25RetrieverConfig() + + @pytest.fixture + def mock_faiss_vector_store(self, mocker): + return mocker.patch("metagpt.rag.retrievers.factory.FaissVectorStore") + + @pytest.fixture + def mock_storage_context(self, mocker): + return mocker.patch("metagpt.rag.retrievers.factory.StorageContext") + + @pytest.fixture + def mock_vector_store_index(self, mocker): + return mocker.patch("metagpt.rag.retrievers.factory.VectorStoreIndex") + + @pytest.fixture + def mock_dynamic_bm25_retriever(self, mocker): + mock = mocker.MagicMock(spec=DynamicBM25Retriever) + return mocker.patch("metagpt.rag.retrievers.factory.DynamicBM25Retriever", mock) + + def test_get_retriever_with_no_configs_returns_default_retriever(self, mock_base_index): + factory = RetrieverFactory() + retriever = factory.get_retriever(index=mock_base_index) + assert isinstance(retriever, RAGRetriever) + + def test_get_retriever_with_specific_config_returns_correct_retriever( + self, + mock_base_index, + mock_faiss_retriever_config, + mock_faiss_vector_store, + mock_storage_context, + mock_vector_store_index, + ): + factory = RetrieverFactory() + retriever = factory.get_retriever(index=mock_base_index, configs=[mock_faiss_retriever_config]) + assert isinstance(retriever, FAISSRetriever) + + def test_get_retriever_with_multiple_configs_returns_hybrid_retriever( + self, + 
mock_base_index, + mock_faiss_retriever_config, + mock_bm25_retriever_config, + mock_faiss_vector_store, + mock_storage_context, + mock_vector_store_index, + mock_dynamic_bm25_retriever, + ): + factory = RetrieverFactory() + retriever = factory.get_retriever( + index=mock_base_index, configs=[mock_faiss_retriever_config, mock_bm25_retriever_config] + ) + assert isinstance(retriever, SimpleHybridRetriever) + + def test_get_retriever_with_unknown_config_raises_value_error(self, mock_base_index, mocker): + mock_unknown_config = mocker.MagicMock() + factory = RetrieverFactory() + with pytest.raises(ValueError): + factory.get_retriever(index=mock_base_index, configs=[mock_unknown_config])