From baa30d3ced14fb4b1ef41e61a7e162954ac70b06 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 7 Mar 2024 17:56:35 +0800 Subject: [PATCH] from objs --- examples/rag_pipeline.py | 11 +++++ metagpt/rag/engines/simple.py | 84 ++++++++++++++++++++++++++--------- 2 files changed, 73 insertions(+), 22 deletions(-) diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index daf4014fc..68b6a3741 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -121,6 +121,16 @@ class RAGExample: player: Player = nodes[0].metadata["obj"] print(player.name) + async def rag_ini_objs(self): + """This example show how to from objs, will print something like: + + Same as rag_add_objs + """ + pre_engine = self.engine + self.engine = SimpleEngine.from_objs(retriever_configs=[FAISSRetrieverConfig()]) + await self.rag_add_objs() + self.engine = pre_engine + async def rag_chromadb(self): """This example show how to use chromadb. how to save and load index. will print something like: @@ -174,6 +184,7 @@ async def main(): await e.rag_pipeline() await e.rag_add_docs() await e.rag_add_objs() + await e.rag_ini_objs() await e.rag_chromadb() diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 469acbacf..5f6fa01ad 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -38,14 +38,16 @@ from metagpt.rag.schema import ( BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig, + BM25RetrieverConfig, ObjectNode, ) from metagpt.utils.common import import_class class SimpleEngine(RetrieverQueryEngine): - """ - SimpleEngine is a lightweight and easy-to-use search engine that integrates + """SimpleEngine is designed to be simple and straightforward. + + It is a lightweight and easy-to-use search engine that integrates document reading, embedding, indexing, retrieving, and ranking functionalities into a single, straightforward workflow. It is designed to quickly set up a search engine from a collection of documents. @@ -78,7 +80,9 @@ class SimpleEngine(RetrieverQueryEngine): retriever_configs: list[BaseRetrieverConfig] = None, ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": - """This engine is designed to be simple and straightforward + """From docs. + + Must provide either `input_dir` or `input_files`. Args: input_dir: Path to the directory. @@ -89,6 +93,9 @@ class SimpleEngine(RetrieverQueryEngine): retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever. ranker_configs: Configuration for rankers. """ + if not input_dir and not input_files: + raise ValueError("Must provide either `input_dir` or `input_files`.") + documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() index = VectorStoreIndex.from_documents( documents=documents, @@ -97,6 +104,39 @@ class SimpleEngine(RetrieverQueryEngine): ) return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) + @classmethod + def from_objs( + cls, + objs: Optional[list[RAGObject]] = None, + transformations: Optional[list[TransformComponent]] = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, + ) -> "SimpleEngine": + """From objs. + + Args: + objs: List of RAGObject. + transformations: Parse documents to nodes. Default [SentenceSplitter]. + embed_model: Parse nodes to embedding. Must supported by llama index. Default OpenAIEmbedding. + llm: Must supported by llama index. Default OpenAI. + retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever. + ranker_configs: Configuration for rankers. + """ + # check + if not retriever_configs or any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs): + raise ValueError("Must provide retriever_configs, and BM25RetrieverConfig is not supported.") + + objs = objs or [] + nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] + index = VectorStoreIndex( + nodes=nodes, + transformations=transformations or [SentenceSplitter()], + embed_model=embed_model or get_rag_embedding(), + ) + return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) + @classmethod def from_index( cls, @@ -110,25 +150,6 @@ class SimpleEngine(RetrieverQueryEngine): index = get_index(index_config, embed_model=embed_model or get_rag_embedding()) return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) - @classmethod - def _from_index( - cls, - index: BaseIndex, - llm: LLM = None, - retriever_configs: list[BaseRetrieverConfig] = None, - ranker_configs: list[BaseRankerConfig] = None, - ) -> "SimpleEngine": - llm = llm or get_rag_llm() - retriever = get_retriever(configs=retriever_configs, index=index) - rankers = get_rankers(configs=ranker_configs, llm=llm) - - return cls( - retriever=retriever, - node_postprocessors=rankers, - response_synthesizer=get_response_synthesizer(llm=llm), - index=index, - ) - async def asearch(self, content: str, **kwargs) -> str: """Inplement tools.SearchInterface""" return await self.aquery(content) @@ -156,6 +177,25 @@ class SimpleEngine(RetrieverQueryEngine): nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] self._save_nodes(nodes) + @classmethod + def _from_index( + cls, + index: BaseIndex, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, + ) -> "SimpleEngine": + llm = llm or get_rag_llm() + retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever + rankers = get_rankers(configs=ranker_configs, llm=llm) # Default [] + + return cls( + retriever=retriever, + node_postprocessors=rankers, + response_synthesizer=get_response_synthesizer(llm=llm), + index=index, + ) + def _ensure_retriever_modifiable(self): if not isinstance(self.retriever, ModifiableRAGRetriever): raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")