diff --git a/examples/example.json b/examples/example.json
deleted file mode 100644
index 996cbec3b..000000000
--- a/examples/example.json
+++ /dev/null
@@ -1,10 +0,0 @@
-[
-    {
-        "source": "Which facial cleanser is good for oily skin?",
-        "output": "ABC cleanser is preferred by many with oily skin."
-    },
-    {
-        "source": "Is L'Oreal good to use?",
-        "output": "L'Oreal is a popular brand with many positive reviews."
-    }
-]
\ No newline at end of file
diff --git a/examples/example.xlsx b/examples/example.xlsx
deleted file mode 100644
index 85fda644e..000000000
Binary files a/examples/example.xlsx and /dev/null differ
diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py
new file mode 100644
index 000000000..5b47cec62
--- /dev/null
+++ b/examples/rag_pipeline.py
@@ -0,0 +1,102 @@
+"""RAG pipeline"""
+import asyncio
+
+import faiss
+from llama_index import (
+    ServiceContext,
+    SimpleDirectoryReader,
+    StorageContext,
+    VectorStoreIndex,
+)
+from llama_index.postprocessor import LLMRerank
+from llama_index.retrievers import BM25Retriever, VectorIndexRetriever
+from llama_index.vector_stores.faiss import FaissVectorStore
+
+from metagpt.const import EXAMPLE_PATH
+from metagpt.rag.llm import get_default_llm
+from metagpt.rag.retrievers import SimpleHybridRetriever
+from metagpt.utils.embedding import get_embedding
+
+DOC_PATH = EXAMPLE_PATH / "data/rag.txt"
+QUESTION = "What are key qualities to be a good writer?"
+TOPK = 5
+
+
+def print_result(nodes, extra="retrieve"):
+    """Print retrieve/rerank results."""
+    print("-" * 50)
+    print(f"{extra} result")
+    for i, node in enumerate(nodes):
+        print(f"{i}. {node.text[:10]}..., {node.score}")
+
+
+async def rag_pipeline():
+    """This example runs a RAG pipeline with FAISS and BM25 retrievers plus an LLM ranker, and prints something like:
+
+    --------------------------------------------------
+    faiss retrieve result
+    0. I highly r..., 0.3958844542503357
+    1. I wrote cu..., 0.41629382967948914
+    2. Productivi..., 0.4318419098854065
+    3. Some sort ..., 0.45991092920303345
+    --------------------------------------------------
+    bm25 retrieve result
+    0. I highly r..., 0.19445682103516615
+    1. Some sort ..., 0.18688966233196197
+    2. Productivi..., 0.17071309618829872
+    3. I wrote cu..., 0.15878996566615383
+    --------------------------------------------------
+    hybrid retrieve result
+    0. I highly r..., 0.3958844542503357
+    1. I wrote cu..., 0.41629382967948914
+    2. Productivi..., 0.4318419098854065
+    3. Some sort ..., 0.45991092920303345
+    --------------------------------------------------
+    llm ranker result
+    0. Productivi..., 10.0
+    1. I wrote cu..., 7.0
+    2. I highly r..., 5.0
+    """
+    # Documents; many kinds of readers are available for loading documents.
+    documents = SimpleDirectoryReader(input_files=[DOC_PATH]).load_data()
+
+    # Service Context, a bundle of resources for llm/embedding/node_parser.
+    service_context = ServiceContext.from_defaults(llm=get_default_llm(), embed_model=get_embedding())
+
+    # Nodes, chunks of documents.
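+    # The node parser (taken from the service context) splits each document into
+    # chunks; these nodes are what get embedded and indexed below.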
+    node_parser = service_context.node_parser
+    nodes = node_parser.get_nodes_from_documents(documents)
+
+    # Index-FAISS
+    vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))  # dimensions of text-embedding-ada-002
+    storage_context = StorageContext.from_defaults(vector_store=vector_store)
+    vector_index = VectorStoreIndex(nodes=nodes, storage_context=storage_context, service_context=service_context)
+
+    # Retriever-FAISS
+    faiss_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=TOPK)
+    faiss_retrieve_nodes = await faiss_retriever.aretrieve(QUESTION)
+    print_result(faiss_retrieve_nodes, extra="faiss retrieve")
+
+    # Retriever-BM25
+    bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=TOPK)
+    bm25_retrieve_nodes = await bm25_retriever.aretrieve(QUESTION)
+    print_result(bm25_retrieve_nodes, extra="bm25 retrieve")
+
+    # Retriever-Hybrid
+    hybrid_retriever = SimpleHybridRetriever(faiss_retriever, bm25_retriever)
+    hybrid_retrieve_nodes = await hybrid_retriever.aretrieve(QUESTION)
+    print_result(hybrid_retrieve_nodes, extra="hybrid retrieve")
+
+    # Ranker-LLM
+    llm_ranker = LLMRerank(top_n=TOPK, service_context=service_context)
+    llm_rank_nodes = llm_ranker.postprocess_nodes(faiss_retrieve_nodes, query_str=QUESTION)
+    print_result(llm_rank_nodes, extra="llm ranker")
+
+
+async def main():
+    """RAG pipeline"""
+    await rag_pipeline()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/rag_search.py b/examples/rag_search.py
new file mode 100644
index 000000000..222573476
--- /dev/null
+++ b/examples/rag_search.py
@@ -0,0 +1,19 @@
+"""Agent with RAG search"""
+import asyncio
+
+from examples.rag_pipeline import DOC_PATH, QUESTION, TOPK
+from metagpt.logs import logger
+from metagpt.rag.engines import SimpleEngine
+from metagpt.roles import Sales
+
+
+async def search():
+    """Agent with RAG search"""
+    store = SimpleEngine.from_docs(input_files=[DOC_PATH], similarity_top_k=TOPK)
+    role = Sales(profile="Sales", store=store)
+    result = await role.run(QUESTION)
+    logger.info(result)
+
+
+if __name__ == "__main__":
+    asyncio.run(search())
diff --git a/metagpt/rag/__init__.py b/metagpt/rag/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/metagpt/rag/engines/__init__.py b/metagpt/rag/engines/__init__.py
new file mode 100644
index 000000000..7b4e37e88
--- /dev/null
+++ b/metagpt/rag/engines/__init__.py
@@ -0,0 +1,3 @@
+from metagpt.rag.engines.simple import SimpleEngine
+
+__all__ = ["SimpleEngine"]
diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py
new file mode 100644
index 000000000..7532f6620
--- /dev/null
+++ b/metagpt/rag/engines/simple.py
@@ -0,0 +1,48 @@
+"""Simple Engine."""
+from typing import Optional
+
+from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
+from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
+from llama_index.embeddings.base import BaseEmbedding
+from llama_index.llms.llm import LLM
+from llama_index.query_engine import RetrieverQueryEngine
+from llama_index.retrievers import VectorIndexRetriever
+
+from metagpt.rag.llm import get_default_llm
+from metagpt.utils.embedding import get_embedding
+
+
+class SimpleEngine(RetrieverQueryEngine):
+    """
+    SimpleEngine is a search engine that uses a vector index for retrieving documents.
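+    It also implements tools.SearchInterface via `asearch`, so it can be passed
+    to a role (e.g. Sales) as its store.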
+ """ + + @classmethod + def from_docs( + cls, + input_dir: str = None, + input_files: list = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + # node parser kwargs + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + # retrieve kwargs + similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, + ) -> "SimpleEngine": + """This engine is designed to be simple and straightforward""" + documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() + service_context = ServiceContext.from_defaults( + embed_model=embed_model or get_embedding(), + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + llm=llm or get_default_llm(), + ) + index = VectorStoreIndex.from_documents(documents, service_context=service_context) + retriever = VectorIndexRetriever(index=index, similarity_top_k=similarity_top_k) + + return SimpleEngine(retriever=retriever) + + async def asearch(self, content: str, **kwargs) -> str: + """Inplement tools.SearchInterface""" + return await self.aquery(content) diff --git a/metagpt/rag/llm.py b/metagpt/rag/llm.py new file mode 100644 index 000000000..e67be1416 --- /dev/null +++ b/metagpt/rag/llm.py @@ -0,0 +1,7 @@ +from llama_index.llms import OpenAI + +from metagpt.config2 import config + + +def get_default_llm() -> OpenAI: + return OpenAI(api_base=config.llm.base_url, api_key=config.llm.api_key) diff --git a/metagpt/rag/rankers/__init__.py b/metagpt/rag/rankers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metagpt/rag/rankers/base.py b/metagpt/rag/rankers/base.py new file mode 100644 index 000000000..482fc4aef --- /dev/null +++ b/metagpt/rag/rankers/base.py @@ -0,0 +1,20 @@ +"""Base Ranker.""" + +from abc import abstractmethod +from typing import Optional + +from llama_index import QueryBundle +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import NodeWithScore + + +class RAGRanker(BaseNodePostprocessor): + """inherit from llama_index""" + + @abstractmethod + def _postprocess_nodes( + self, + nodes: list[NodeWithScore], + query_bundle: Optional[QueryBundle] = None, + ) -> list[NodeWithScore]: + """postprocess nodes.""" diff --git a/metagpt/rag/retrievers/__init__.py b/metagpt/rag/retrievers/__init__.py new file mode 100644 index 000000000..799766870 --- /dev/null +++ b/metagpt/rag/retrievers/__init__.py @@ -0,0 +1,4 @@ +"""init""" +from metagpt.rag.retrievers.hybrid import SimpleHybridRetriever + +__all__ = ["SimpleHybridRetriever"] diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py new file mode 100644 index 000000000..c0291f217 --- /dev/null +++ b/metagpt/rag/retrievers/base.py @@ -0,0 +1,18 @@ +"""Base retriever.""" + + +from abc import abstractmethod + +from llama_index.retrievers import BaseRetriever +from llama_index.schema import NodeWithScore, QueryType + + +class RAGRetriever(BaseRetriever): + """inherit from llama_index""" + + @abstractmethod + async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]: + """retrieve nodes""" + + def _retrieve(self, query: QueryType) -> list[NodeWithScore]: + """retrieve nodes""" diff --git a/metagpt/rag/retrievers/hybrid.py b/metagpt/rag/retrievers/hybrid.py new file mode 100644 index 000000000..e6b526b38 --- /dev/null +++ b/metagpt/rag/retrievers/hybrid.py @@ -0,0 +1,36 @@ +"""Hybrid retriever.""" +from llama_index.schema import QueryType + +from metagpt.rag.retrievers.base import RAGRetriever + + +class SimpleHybridRetriever(RAGRetriever): + """ + SimpleHybridRetriever 
+    """
+
+    def __init__(self, *retrievers):
+        self.retrievers: list[RAGRetriever] = list(retrievers)
+        super().__init__()
+
+    async def _aretrieve(self, query: QueryType, **kwargs):
+        """
+        Asynchronously retrieves and aggregates search results from all configured retrievers.
+
+        This method queries each retriever in the `retrievers` list with the given query and
+        additional keyword arguments. It then combines the results, ensuring that each node is
+        unique, based on the node's ID.
+        """
+        all_nodes = []
+        for retriever in self.retrievers:
+            nodes = await retriever.aretrieve(query, **kwargs)
+            all_nodes.extend(nodes)
+
+        # combine all nodes, keeping the first occurrence of each node ID
+        result = []
+        node_ids = set()
+        for n in all_nodes:
+            if n.node.node_id not in node_ids:
+                result.append(n)
+                node_ids.add(n.node.node_id)
+        return result
diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py
index bc449b5cd..e5cb12778 100644
--- a/metagpt/roles/sales.py
+++ b/metagpt/roles/sales.py
@@ -11,7 +11,6 @@ from typing import Optional
 
 from pydantic import Field, model_validator
 
 from metagpt.actions import SearchAndSummarize, UserRequirement
-from metagpt.document_store.base_store import BaseStore
 from metagpt.roles import Role
 from metagpt.tools.search_engine import SearchEngine
@@ -27,7 +26,7 @@ class Sales(Role):
         "delivered with the professionalism and courtesy expected of a seasoned sales guide."
     )
 
-    store: Optional[BaseStore] = Field(default=None, exclude=True)
+    store: Optional[object] = Field(default=None, exclude=True)  # must implement tools.SearchInterface
 
     @model_validator(mode="after")
     def validate_stroe(self):
diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py
index c1f604df9..8d265e9f3 100644
--- a/metagpt/tools/__init__.py
+++ b/metagpt/tools/__init__.py
@@ -30,3 +30,8 @@ class WebBrowserEngineType(Enum):
     def __missing__(cls, key):
         """Default type conversion"""
         return cls.CUSTOM
+
+
+class SearchInterface:
+    async def asearch(self, *args, **kwargs):
+        ...
diff --git a/tests/metagpt/document_store/test_faiss_store.py b/tests/metagpt/document_store/test_faiss_store.py
index 1a159b413..0c5a55e0f 100644
--- a/tests/metagpt/document_store/test_faiss_store.py
+++ b/tests/metagpt/document_store/test_faiss_store.py
@@ -28,7 +28,7 @@ def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int
 async def test_search_json(mocker):
     mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
-    store = FaissStore(EXAMPLE_PATH / "example.json")
+    store = FaissStore(EXAMPLE_PATH / "data/example.json")
     role = Sales(profile="Sales", store=store)
     query = "Which facial cleanser is good for oily skin?"
     result = await role.run(query)
@@ -39,7 +39,7 @@ async def test_search_xlsx(mocker):
     mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
-    store = FaissStore(EXAMPLE_PATH / "example.xlsx")
+    store = FaissStore(EXAMPLE_PATH / "data/example.xlsx")
     role = Sales(profile="Sales", store=store)
     query = "Which facial cleanser is good for oily skin?"
     result = await role.run(query)
@@ -50,7 +50,7 @@ async def test_write(mocker):
     mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
-    store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
+    store = FaissStore(EXAMPLE_PATH / "data/example.xlsx", meta_col="Answer", content_col="Question")
     _faiss_store = store.write()
     assert _faiss_store.storage_context.docstore
     assert _faiss_store.storage_context.vector_store.client
diff --git a/tests/metagpt/rag/engine/test_simple.py b/tests/metagpt/rag/engine/test_simple.py
new file mode 100644
index 000000000..4eb1d0b6d
--- /dev/null
+++ b/tests/metagpt/rag/engine/test_simple.py
@@ -0,0 +1,67 @@
+from unittest.mock import AsyncMock
+
+import pytest
+
+from metagpt.rag.engines import SimpleEngine
+
+
+class TestSimpleEngineFromDocs:
+    def test_from_docs(self, mocker):
+        # Mock dependencies
+        mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
+        mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
+
+        mock_service_context = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults")
+        mock_vector_store_index = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
+        mock_vector_index_retriever = mocker.patch("metagpt.rag.engines.simple.VectorIndexRetriever")
+
+        # Setup
+        input_dir = "test_dir"
+        input_files = ["test_file1", "test_file2"]
+        embed_model = mocker.MagicMock()
+        llm = mocker.MagicMock()
+        chunk_size = 100
+        chunk_overlap = 10
+        similarity_top_k = 5
+
+        # Execute
+        engine = SimpleEngine.from_docs(
+            input_dir=input_dir,
+            input_files=input_files,
+            embed_model=embed_model,
+            llm=llm,
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            similarity_top_k=similarity_top_k,
+        )
+
+        # Assertions
+        mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
+        mock_service_context.assert_called_once_with(
+            embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm
+        )
+        mock_vector_store_index.assert_called_once_with(
+            ["document1", "document2"], service_context=mock_service_context.return_value
+        )
+        mock_vector_index_retriever.assert_called_once_with(
+            index=mock_vector_store_index.return_value, similarity_top_k=similarity_top_k
+        )
+        assert isinstance(engine, SimpleEngine)
+
+    @pytest.mark.asyncio
+    async def test_asearch_calls_aquery(self, mocker):
+        # Mock
+        test_query = "test query"
+        expected_result = "expected result"
+        mock_aquery = AsyncMock(return_value=expected_result)
+
+        # Setup
+        engine = SimpleEngine(retriever=mocker.MagicMock())
+        engine.aquery = mock_aquery
+
+        # Execute
+        result = await engine.asearch(test_query)
+
+        # Assertions
+        mock_aquery.assert_called_once_with(test_query)
+        assert result == expected_result
diff --git a/tests/metagpt/rag/retrievers/test_hybrid_retriever.py b/tests/metagpt/rag/retrievers/test_hybrid_retriever.py
new file mode 100644
index 000000000..62d976ba2
--- /dev/null
+++ b/tests/metagpt/rag/retrievers/test_hybrid_retriever.py
@@ -0,0 +1,39 @@
+from unittest.mock import AsyncMock
+
+import pytest
+from llama_index.schema import NodeWithScore, TextNode
+
+from metagpt.rag.retrievers import SimpleHybridRetriever
+
+
+class TestSimpleHybridRetriever:
+    @pytest.mark.asyncio
+    async def test_aretrieve(self):
+        question = "test query"
+
+        # Create mock retrievers
+        mock_retriever1 = AsyncMock()
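+        # AsyncMock makes `aretrieve` awaitable; awaiting it returns the
+        # configured return_value, standing in for a real RAGRetriever.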
+        mock_retriever1.aretrieve.return_value = [
+            NodeWithScore(node=TextNode(id_="1"), score=1.0),
+            NodeWithScore(node=TextNode(id_="2"), score=0.95),
+        ]
+
+        mock_retriever2 = AsyncMock()
+        mock_retriever2.aretrieve.return_value = [
+            NodeWithScore(node=TextNode(id_="2"), score=0.95),
+            NodeWithScore(node=TextNode(id_="3"), score=0.8),
+        ]
+
+        # Instantiate the SimpleHybridRetriever with the mock retrievers
+        hybrid_retriever = SimpleHybridRetriever(mock_retriever1, mock_retriever2)
+
+        # Call the _aretrieve method
+        results = await hybrid_retriever._aretrieve(question)
+
+        # Check if the results are as expected
+        assert len(results) == 3  # Should be 3 unique nodes
+        assert set(node.node.node_id for node in results) == {"1", "2", "3"}
+
+        # Check the deduplicated node's score (both retrievers scored node "2" as 0.95)
+        node_scores = {node.node.node_id: node.score for node in results}
+        assert node_scores["2"] == 0.95
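
Usage sketch (illustrative only, not part of the diff): the engine added in metagpt/rag/engines/simple.py can be exercised end to end in a few lines. `my_docs` is a hypothetical directory of text files; everything else comes from the code above, assuming a llama_index version matching this PR.

    import asyncio

    from metagpt.rag.engines import SimpleEngine


    async def main():
        # Build an engine over a hypothetical directory of text files.
        engine = SimpleEngine.from_docs(input_dir="my_docs", similarity_top_k=3)
        # `asearch` satisfies tools.SearchInterface, so this same object can be
        # passed to Sales(store=...) as in examples/rag_search.py.
        answer = await engine.asearch("What are key qualities to be a good writer?")
        print(answer)


    if __name__ == "__main__":
        asyncio.run(main())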