diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 045d2162a..3eb1dfd9e 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -61,10 +61,10 @@ class RAGExample: self._print_title("RAG Pipeline") nodes = await self.engine.aretrieve(question) - self._print_result(nodes, state="Retrieve") + self._print_retrieve_result(nodes) answer = await self.engine.aquery(question) - self._print_result(answer, state="Query") + self._print_query_result(answer) async def rag_add_docs(self): """This example show how to add docs, before add docs llm anwser I don't know, after add docs llm give the correct answer, will print something like: @@ -160,28 +160,32 @@ class RAGExample: # query answer = engine.query(TRAVEL_QUESTION) - self._print_result(answer, state="Query") + self._print_query_result(answer) @staticmethod def _print_title(title): logger.info(f"{'#'*30} {title} {'#'*30}") @staticmethod - def _print_result(result, state="Retrieve"): - """print retrieve or query result""" - logger.info(f"{state} Result:") + def _print_retrieve_result(result): + """Print retrieve result.""" + logger.info("Retrieve Result:") - if state == "Retrieve": - for i, node in enumerate(result): - logger.info(f"{i}. {node.text[:10]}..., {node.score}") - logger.info("") - return + for i, node in enumerate(result): + logger.info(f"{i}. {node.text[:10]}..., {node.score}") + + logger.info("") + + @staticmethod + def _print_query_result(result): + """Print query result.""" + logger.info("Query Result:") logger.info(f"{result}\n") async def _retrieve_and_print(self, question): nodes = await self.engine.aretrieve(question) - self._print_result(nodes, state="Retrieve") + self._print_retrieve_result(nodes) return nodes diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 3b6d3fdc9..ebe467ecf 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -1,7 +1,8 @@ """Simple Engine.""" import json -from typing import Optional +import os +from typing import Optional, Union from llama_index.core import SimpleDirectoryReader, VectorStoreIndex from llama_index.core.callbacks.base import CallbackManager @@ -128,8 +129,8 @@ class SimpleEngine(RetrieverQueryEngine): retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever. ranker_configs: Configuration for rankers. """ - if not retriever_configs or any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs): - raise ValueError("Must provide retriever_configs, and BM25RetrieverConfig is not supported.") + if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs): + raise ValueError("In BM25RetrieverConfig, Objs must not be empty.") objs = objs or [] nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] @@ -182,11 +183,11 @@ class SimpleEngine(RetrieverQueryEngine): nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] self._save_nodes(nodes) - def persist(self, persist_dir: str, **kwargs): + def persist(self, persist_dir: Union[str, os.PathLike], **kwargs): """Persist.""" self._ensure_retriever_persistable() - self._persist(persist_dir, **kwargs) + self._persist(str(persist_dir), **kwargs) @classmethod def _from_index( diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index 504faafc6..e6c87c64a 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -7,7 +7,12 @@ from llama_index.core.indices.base import BaseIndex from llama_index.vector_stores.faiss import FaissVectorStore from metagpt.rag.factories.base import ConfigFactory -from metagpt.rag.schema import BaseIndexConfig, ChromaIndexConfig, FAISSIndexConfig +from metagpt.rag.schema import ( + BaseIndexConfig, + BM25IndexConfig, + ChromaIndexConfig, + FAISSIndexConfig, +) from metagpt.rag.vector_stores.chroma import ChromaVectorStore @@ -16,6 +21,7 @@ class RAGIndexFactory(ConfigFactory): creators = { FAISSIndexConfig: self._create_faiss, ChromaIndexConfig: self._create_chroma, + BM25IndexConfig: self._create_bm25, } super().__init__(creators) @@ -46,5 +52,12 @@ class RAGIndexFactory(ConfigFactory): ) return index + def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex: + embed_model = self._extract_embed_model(config, **kwargs) + + storage_context = StorageContext.from_defaults(persist_dir=config.persist_path) + index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model) + return index + get_index = RAGIndexFactory().get_index diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index d9ec6b12d..2581cbef0 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -1,5 +1,7 @@ """RAG Retriever Factory.""" +import copy + import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex @@ -69,8 +71,9 @@ class RetrieverFactory(ConfigFactory): return FAISSRetriever(**config.model_dump()) def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: - config.index = self._extract_index(config, **kwargs) - return DynamicBM25Retriever.from_defaults(**config.model_dump()) + config.index = copy.deepcopy(self._extract_index(config, **kwargs)) + nodes = list(config.index.docstore.docs.values()) + return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: db = chromadb.PersistentClient(path=str(config.persist_path)) diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index 68037c31f..241820cf4 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -1,6 +1,10 @@ """BM25 retriever.""" +from typing import Callable, Optional -from llama_index.core.schema import BaseNode +from llama_index.core import VectorStoreIndex +from llama_index.core.callbacks.base import CallbackManager +from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K +from llama_index.core.schema import BaseNode, IndexNode from llama_index.retrievers.bm25 import BM25Retriever from rank_bm25 import BM25Okapi @@ -8,8 +12,36 @@ from rank_bm25 import BM25Okapi class DynamicBM25Retriever(BM25Retriever): """BM25 retriever.""" + def __init__( + self, + nodes: list[BaseNode], + tokenizer: Optional[Callable[[str], list[str]]] = None, + similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, + callback_manager: Optional[CallbackManager] = None, + objects: Optional[list[IndexNode]] = None, + object_map: Optional[dict] = None, + verbose: bool = False, + index: VectorStoreIndex = None, + ) -> None: + super().__init__( + nodes=nodes, + tokenizer=tokenizer, + similarity_top_k=similarity_top_k, + callback_manager=callback_manager, + object_map=object_map, + objects=objects, + verbose=verbose, + ) + self._index = index + def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: - """Support add nodes""" + """Support add nodes.""" self._nodes.extend(nodes) self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] self.bm25 = BM25Okapi(self._corpus) + + self._index.insert_nodes(nodes, **kwargs) + + def persist(self, persist_dir: str, **kwargs) -> None: + """Support persist.""" + self._index.storage_context.persist(persist_dir) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index d75681a8f..ade4b3def 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -89,6 +89,10 @@ class ChromaIndexConfig(VectorIndexConfig): collection_name: str = Field(default="metagpt", description="The name of the collection.") +class BM25IndexConfig(BaseIndexConfig): + """Config for bm25-based index.""" + + class ObjectNodeMetadata(BaseModel): """Metadata of ObjectNode.""" diff --git a/tests/metagpt/rag/retrievers/test_bm25_retriever.py b/tests/metagpt/rag/retrievers/test_bm25_retriever.py index 77a1db495..28b37c86b 100644 --- a/tests/metagpt/rag/retrievers/test_bm25_retriever.py +++ b/tests/metagpt/rag/retrievers/test_bm25_retriever.py @@ -1,4 +1,5 @@ import pytest +from llama_index.core import VectorStoreIndex from llama_index.core.schema import Node from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever @@ -14,13 +15,16 @@ class TestDynamicBM25Retriever: self.doc2.get_content.return_value = "Document content 2" self.mock_nodes = [self.doc1, self.doc2] + # 模拟index + index = mocker.MagicMock(spec=VectorStoreIndex) + # 模拟nodes和tokenizer参数 mock_nodes = [] mock_tokenizer = mocker.MagicMock() self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) # 初始化DynamicBM25Retriever对象,并提供必需的参数 - self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer) + self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index) def test_add_docs_updates_nodes_and_corpus(self): # Execute