diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py
index dc13adf28..3b6d3fdc9 100644
--- a/metagpt/rag/engines/simple.py
+++ b/metagpt/rag/engines/simple.py
@@ -1,5 +1,6 @@
 """Simple Engine."""
 
+import json
 from typing import Optional
 
 from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
@@ -33,7 +34,8 @@ from metagpt.rag.factories import (
 )
 from metagpt.rag.interface import RAGObject
 from metagpt.rag.llm import get_rag_llm
-from metagpt.rag.retrievers.base import ModifiableRAGRetriever
+from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
+from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
 from metagpt.rag.schema import (
     BaseIndexConfig,
     BaseRankerConfig,
@@ -180,6 +182,12 @@ class SimpleEngine(RetrieverQueryEngine):
         nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
         self._save_nodes(nodes)
 
+    def persist(self, persist_dir: str, **kwargs):
+        """Persist the retriever to persist_dir; raises TypeError if the retriever is not persistable."""
+        self._ensure_retriever_persistable()
+
+        self._persist(persist_dir, **kwargs)
+
     @classmethod
     def _from_index(
         cls,
@@ -200,15 +208,36 @@ class SimpleEngine(RetrieverQueryEngine):
         )
 
     def _ensure_retriever_modifiable(self):
-        if not isinstance(self.retriever, ModifiableRAGRetriever):
-            raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")
+        self._ensure_retriever_of_type(ModifiableRAGRetriever)
+
+    def _ensure_retriever_persistable(self):
+        self._ensure_retriever_of_type(PersistableRAGRetriever)
+
+    def _ensure_retriever_of_type(self, required_type: type[BaseRetriever]):
+        """Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.
+
+        Args:
+            required_type: The class that the retriever is expected to be an instance of.
+
+        Raises:
+            TypeError: If the retriever (or, for a hybrid, every one of its components) is not of required_type.
+        """
+        if isinstance(self.retriever, SimpleHybridRetriever):
+            if not any(isinstance(r, required_type) for r in self.retriever.retrievers):
+                raise TypeError(
+                    f"Must have at least one retriever of type {required_type.__name__} in SimpleHybridRetriever"
+                )
+            # A hybrid is judged by its components only, not by the methods the wrapper itself defines.
+            return
+
+        if not isinstance(self.retriever, required_type):
+            raise TypeError(f"The retriever is not of type {required_type.__name__}: {type(self.retriever)}")
 
     def _save_nodes(self, nodes: list[BaseNode]):
-        # for search in memory
         self.retriever.add_nodes(nodes)
 
-        # for persist
-        self.index.insert_nodes(nodes)
+    def _persist(self, persist_dir: str, **kwargs):
+        self.retriever.persist(persist_dir, **kwargs)
 
     @staticmethod
     def _try_reconstruct_obj(nodes: list[NodeWithScore]):
@@ -216,7 +245,8 @@ class SimpleEngine(RetrieverQueryEngine):
         for node in nodes:
             if node.metadata.get("is_obj", False):
                 obj_cls = import_class(node.metadata["obj_cls_name"], node.metadata["obj_mod_name"])
-                node.metadata["obj"] = obj_cls(**node.metadata["obj_dict"])
+                obj_dict = json.loads(node.metadata["obj_json"])
+                node.metadata["obj"] = obj_cls(**obj_dict)
 
     @staticmethod
     def _fix_document_metadata(documents: list[Document]):
diff --git a/metagpt/rag/interface.py b/metagpt/rag/interface.py
index 9f5d8375c..9af2c1219 100644
--- a/metagpt/rag/interface.py
+++ b/metagpt/rag/interface.py
@@ -1,6 +1,6 @@
 """RAG Interfaces."""
 
-from typing import Any, Protocol
+from typing import Protocol
 
 
 class RAGObject(Protocol):
@@ -9,8 +9,8 @@ class RAGObject(Protocol):
     def rag_key(self) -> str:
         """For rag search."""
 
-    def model_dump(self) -> dict[str, Any]:
+    def model_dump_json(self) -> str:
         """For rag persist.
 
-        Pydantic Model don't need to implement this, as there is a built-in function named model_dump.
+        Pydantic models don't need to implement this, as they have a built-in method named model_dump_json.
         """
diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py
index ea73a0017..a7b836833 100644
--- a/metagpt/rag/retrievers/base.py
+++ b/metagpt/rag/retrievers/base.py
@@ -31,3 +31,17 @@ class ModifiableRAGRetriever(RAGRetriever):
     @abstractmethod
     def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
         """To support add docs, must inplement this func"""
+
+
+class PersistableRAGRetriever(RAGRetriever):
+    """Support persistence."""
+
+    @classmethod
+    def __subclasshook__(cls, C):
+        if cls is PersistableRAGRetriever:
+            return check_methods(C, "persist")
+        return NotImplemented
+
+    @abstractmethod
+    def persist(self, persist_dir: str, **kwargs) -> None:
+        """To support persist, must implement this func"""
diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py
index 2965f685a..68037c31f 100644
--- a/metagpt/rag/retrievers/bm25_retriever.py
+++ b/metagpt/rag/retrievers/bm25_retriever.py
@@ -8,7 +8,7 @@ from rank_bm25 import BM25Okapi
 class DynamicBM25Retriever(BM25Retriever):
     """BM25 retriever."""
 
-    def add_nodes(self, nodes: list[BaseNode], **kwargs):
+    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
         """Support add nodes"""
         self._nodes.extend(nodes)
         self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
diff --git a/metagpt/rag/retrievers/chroma_retriever.py b/metagpt/rag/retrievers/chroma_retriever.py
index 7832fa878..d41f375e4 100644
--- a/metagpt/rag/retrievers/chroma_retriever.py
+++ b/metagpt/rag/retrievers/chroma_retriever.py
@@ -7,6 +7,11 @@ from llama_index.core.schema import BaseNode
 class ChromaRetriever(VectorIndexRetriever):
     """Chroma retriever."""
 
-    def add_nodes(self, nodes: list[BaseNode], **kwargs):
-        """Support add nodes"""
+    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
+        """Support add nodes."""
         self._index.insert_nodes(nodes, **kwargs)
+
+    def persist(self, persist_dir: str, **kwargs) -> None:
+        """Support persist.
+
+        Chromadb automatically saves, so there is no need to implement."""
diff --git a/metagpt/rag/retrievers/faiss_retriever.py b/metagpt/rag/retrievers/faiss_retriever.py
index 8c649b53e..7e543cce2 100644
--- a/metagpt/rag/retrievers/faiss_retriever.py
+++ b/metagpt/rag/retrievers/faiss_retriever.py
@@ -7,6 +7,10 @@ from llama_index.core.schema import BaseNode
 class FAISSRetriever(VectorIndexRetriever):
     """FAISS retriever."""
 
-    def add_nodes(self, nodes: list[BaseNode], **kwargs):
+    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
         """Support add nodes"""
         self._index.insert_nodes(nodes, **kwargs)
+
+    def persist(self, persist_dir: str, **kwargs) -> None:
+        """Support persist."""
+        self._index.storage_context.persist(persist_dir)
diff --git a/metagpt/rag/retrievers/hybrid_retriever.py b/metagpt/rag/retrievers/hybrid_retriever.py
index 14deb6ebf..c725bfc20 100644
--- a/metagpt/rag/retrievers/hybrid_retriever.py
+++ b/metagpt/rag/retrievers/hybrid_retriever.py
@@ -37,7 +37,16 @@ class SimpleHybridRetriever(RAGRetriever):
                 node_ids.add(n.node.node_id)
         return result
 
-    def add_nodes(self, nodes: list[BaseNode]):
-        """Support add nodes"""
+    def add_nodes(self, nodes: list[BaseNode]) -> None:
+        """Support add nodes."""
         for r in self.retrievers:
             r.add_nodes(nodes)
+
+    def persist(self, persist_dir: str, **kwargs) -> None:
+        """Support persist.
+
+        Components without a ``persist`` method are skipped instead of raising AttributeError.
+        """
+        for r in self.retrievers:
+            if callable(getattr(r, "persist", None)):
+                r.persist(persist_dir, **kwargs)
diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py
index 81db2a0d1..d75681a8f 100644
--- a/metagpt/rag/schema.py
+++ b/metagpt/rag/schema.py
@@ -93,8 +93,10 @@ class ObjectNodeMetadata(BaseModel):
     """Metadata of ObjectNode."""
 
     is_obj: bool = Field(default=True)
-    obj: Any = Field(default=None, description="When retrieve, will reconstruct obj from obj_dict")
-    obj_dict: dict = Field(..., description="Inplement rag.interface.RAGObject.model_dump(), e.g. obj.model_dump()")
+    obj: Any = Field(default=None, description="When retrieve, will reconstruct obj from obj_json")
+    obj_json: str = Field(
+        ..., description="Inplement rag.interface.RAGObject.model_dump_json(), e.g. obj.model_dump_json()"
+    )
     obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__")
     obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__")
 
@@ -110,6 +112,7 @@ class ObjectNode(TextNode):
     @staticmethod
     def get_obj_metadata(obj: RAGObject) -> dict:
         metadata = ObjectNodeMetadata(
-            obj_dict=obj.model_dump(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
+            obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
         )
+
         return metadata.model_dump()
diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py
index 4125d480a..5627957c7 100644
--- a/tests/metagpt/rag/engines/test_simple.py
+++ b/tests/metagpt/rag/engines/test_simple.py
@@ -150,8 +150,8 @@ class TestSimpleEngine:
             def rag_key(self):
                 return ""
 
-            def model_dump(self):
-                return {}
+            def model_dump_json(self):
+                return "{}"
 
         objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
         engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())