diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 1b8a63434..e036f6aa9 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -94,14 +94,19 @@ class SimpleEngine(RetrieverQueryEngine): def add_docs(self, input_files: list[str]): """Add docs to retriever. retriever must has add_nodes func.""" - if not isinstance(self.retriever, ModifiableRAGRetriever): - raise TypeError(f"must be inplement to add_docs: {type(self.retriever)}") + self._ensure_retriever_modifiable() documents = SimpleDirectoryReader(input_files=input_files).load_data() nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents) self.retriever.add_nodes(nodes) - def add_objs(self, obj_list: list[RAGObject]): + def add_objs(self, objs: list[RAGObject]): """Adds objects to the retriever, storing each object's original form in metadata for future reference.""" - nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj}) for obj in obj_list] + self._ensure_retriever_modifiable() + + nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj}) for obj in objs] self.retriever.add_nodes(nodes) + + def _ensure_retriever_modifiable(self): + if not isinstance(self.retriever, ModifiableRAGRetriever): + raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")