add persist

This commit is contained in:
seehi 2024-03-11 20:18:27 +08:00
parent 6a388b53f1
commit 0576ab2ed1
9 changed files with 77 additions and 21 deletions

View file

@ -1,5 +1,6 @@
"""Simple Engine."""
import json
from typing import Optional
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
@ -33,7 +34,8 @@ from metagpt.rag.factories import (
)
from metagpt.rag.interface import RAGObject
from metagpt.rag.llm import get_rag_llm
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseIndexConfig,
BaseRankerConfig,
@ -180,6 +182,12 @@ class SimpleEngine(RetrieverQueryEngine):
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
self._save_nodes(nodes)
def persist(self, persist_dir: str, **kwargs) -> None:
    """Persist the engine's retriever under ``persist_dir``.

    Args:
        persist_dir: Target location, passed unchanged to ``self._persist``.
        **kwargs: Extra options forwarded unchanged to ``self._persist``.

    Raises:
        TypeError: Propagated from ``_ensure_retriever_persistable`` when the
            retriever does not pass the ``PersistableRAGRetriever`` check.
    """
    # Validate first so an unsupported retriever fails before ``_persist`` runs.
    self._ensure_retriever_persistable()
    self._persist(persist_dir, **kwargs)
@classmethod
def _from_index(
cls,
@ -200,15 +208,31 @@ class SimpleEngine(RetrieverQueryEngine):
)
def _ensure_retriever_modifiable(self):
if not isinstance(self.retriever, ModifiableRAGRetriever):
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")
self._ensure_retriever_of_type(ModifiableRAGRetriever)
def _ensure_retriever_persistable(self):
    """Raise TypeError unless the retriever passes the ``PersistableRAGRetriever`` check."""
    self._ensure_retriever_of_type(PersistableRAGRetriever)
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
    """Check that the engine's retriever satisfies ``required_type``.

    For a SimpleHybridRetriever, at least one of its sub-retrievers must be an
    instance of ``required_type``. In every case the retriever itself must
    also pass the ``isinstance`` check.

    Args:
        required_type: The class that the retriever is expected to be an instance of.

    Raises:
        TypeError: If either check fails.
    """
    if isinstance(self.retriever, SimpleHybridRetriever):
        for component in self.retriever.retrievers:
            if isinstance(component, required_type):
                break
        else:
            # The loop finished without a match: no component qualifies.
            raise TypeError(
                f"Must have at least one retriever of type {required_type.__name__} in SimpleHybridRetriever"
            )

    # The top-level retriever is checked as well, hybrid or not.
    if isinstance(self.retriever, required_type):
        return
    raise TypeError(f"The retriever is not of type {required_type.__name__}: {type(self.retriever)}")
def _save_nodes(self, nodes: list[BaseNode]):
    """Add ``nodes`` to both the retriever and the index.

    Args:
        nodes: Nodes handed to ``self.retriever.add_nodes`` and then to
            ``self.index.insert_nodes``.
    """
    # Make the nodes searchable through the in-memory retriever.
    self.retriever.add_nodes(nodes)
    # Also insert them into the index so they are included when persisting.
    self.index.insert_nodes(nodes)
def _persist(self, persist_dir: str, **kwargs):
    """Delegate persistence to ``self.retriever.persist``, forwarding all arguments unchanged."""
    self.retriever.persist(persist_dir, **kwargs)
@staticmethod
def _try_reconstruct_obj(nodes: list[NodeWithScore]):
@ -216,7 +240,8 @@ class SimpleEngine(RetrieverQueryEngine):
for node in nodes:
if node.metadata.get("is_obj", False):
obj_cls = import_class(node.metadata["obj_cls_name"], node.metadata["obj_mod_name"])
node.metadata["obj"] = obj_cls(**node.metadata["obj_dict"])
obj_dict = json.loads(node.metadata["obj_json"])
node.metadata["obj"] = obj_cls(**obj_dict)
@staticmethod
def _fix_document_metadata(documents: list[Document]):

View file

@ -1,6 +1,6 @@
"""RAG Interfaces."""
from typing import Any, Protocol
from typing import Protocol
class RAGObject(Protocol):
@ -9,8 +9,8 @@ class RAGObject(Protocol):
def rag_key(self) -> str:
"""For rag search."""
def model_dump(self) -> dict[str, Any]:
def model_dump_json(self) -> str:
"""For rag persist.
Pydantic Model don't need to implement this, as there is a built-in function named model_dump.
Pydantic Model don't need to implement this, as there is a built-in function named model_dump_json.
"""

View file

@ -31,3 +31,17 @@ class ModifiableRAGRetriever(RAGRetriever):
@abstractmethod
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
    """Add nodes to the retriever; subclasses must implement this to support adding docs."""
class PersistableRAGRetriever(RAGRetriever):
    """Interface for retrievers that can persist their state."""

    @classmethod
    def __subclasshook__(cls, C):
        # Only answer for this exact class; subclasses fall back to the
        # default subclass machinery.
        if cls is not PersistableRAGRetriever:
            return NotImplemented
        # Delegates to check_methods; presumably any class that defines
        # ``persist`` is accepted - confirm against check_methods.
        return check_methods(C, "persist")

    @abstractmethod
    def persist(self, persist_dir: str, **kwargs) -> None:
        """Persist the retriever under ``persist_dir``; subclasses must implement this."""

View file

@ -8,7 +8,7 @@ from rank_bm25 import BM25Okapi
class DynamicBM25Retriever(BM25Retriever):
"""BM25 retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs):
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes"""
self._nodes.extend(nodes)
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]

View file

@ -7,6 +7,11 @@ from llama_index.core.schema import BaseNode
class ChromaRetriever(VectorIndexRetriever):
"""Chroma retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs):
"""Support add nodes"""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
    """Support add nodes.

    Inserts ``nodes`` into the underlying index, forwarding ``kwargs`` to
    ``self._index.insert_nodes``.
    """
    self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
    """Support persist.

    Intentionally a no-op: Chromadb saves automatically, so there is nothing
    to write here. ``persist_dir`` and ``kwargs`` are accepted but unused.
    """

View file

@ -7,6 +7,10 @@ from llama_index.core.schema import BaseNode
class FAISSRetriever(VectorIndexRetriever):
"""FAISS retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs):
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
    """Support add nodes.

    Inserts ``nodes`` into the underlying index, forwarding ``kwargs`` to
    ``self._index.insert_nodes``.
    """
    self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
    """Write the index's storage context to ``persist_dir``.

    Args:
        persist_dir: Passed to ``storage_context.persist``.
        **kwargs: Accepted for interface compatibility; not forwarded.
    """
    storage = self._index.storage_context
    storage.persist(persist_dir)

View file

@ -37,7 +37,12 @@ class SimpleHybridRetriever(RAGRetriever):
node_ids.add(n.node.node_id)
return result
def add_nodes(self, nodes: list[BaseNode]):
"""Support add nodes"""
def add_nodes(self, nodes: list[BaseNode]) -> None:
    """Support add nodes.

    Calls ``add_nodes(nodes)`` on each retriever in ``self.retrievers``, in order.
    """
    for r in self.retrievers:
        r.add_nodes(nodes)
def persist(self, persist_dir: str, **kwargs) -> None:
    """Persist every sub-retriever that supports persistence.

    Args:
        persist_dir: Directory passed to each sub-retriever's ``persist``.
        **kwargs: Extra options forwarded unchanged to each ``persist`` call.
    """
    for r in self.retrievers:
        persist_fn = getattr(r, "persist", None)
        # A hybrid may mix in retrievers that have no ``persist`` method;
        # skip those instead of failing with AttributeError part-way through.
        if callable(persist_fn):
            persist_fn(persist_dir, **kwargs)

View file

@ -93,8 +93,10 @@ class ObjectNodeMetadata(BaseModel):
"""Metadata of ObjectNode."""
is_obj: bool = Field(default=True)
obj: Any = Field(default=None, description="When retrieve, will reconstruct obj from obj_dict")
obj_dict: dict = Field(..., description="Inplement rag.interface.RAGObject.model_dump(), e.g. obj.model_dump()")
obj: Any = Field(default=None, description="When retrieve, will reconstruct obj from obj_json")
obj_json: str = Field(
..., description="Inplement rag.interface.RAGObject.model_dump_json(), e.g. obj.model_dump_json()"
)
obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__")
obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__")
@ -110,6 +112,7 @@ class ObjectNode(TextNode):
@staticmethod
def get_obj_metadata(obj: RAGObject) -> dict:
metadata = ObjectNodeMetadata(
obj_dict=obj.model_dump(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
)
return metadata.model_dump()

View file

@ -150,8 +150,8 @@ class TestSimpleEngine:
def rag_key(self):
return ""
def model_dump(self):
return {}
def model_dump_json(self):
return ""
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())