mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add persist
This commit is contained in:
parent
6a388b53f1
commit
0576ab2ed1
9 changed files with 77 additions and 21 deletions
|
|
@ -1,5 +1,6 @@
|
|||
"""Simple Engine."""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
|
||||
|
|
@ -33,7 +34,8 @@ from metagpt.rag.factories import (
|
|||
)
|
||||
from metagpt.rag.interface import RAGObject
|
||||
from metagpt.rag.llm import get_rag_llm
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseIndexConfig,
|
||||
BaseRankerConfig,
|
||||
|
|
@ -180,6 +182,12 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs):
    """Persist the engine's retriever data to ``persist_dir``.

    Args:
        persist_dir: Target directory, handed through to ``_persist``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_persist``.

    Raises:
        TypeError: Propagated from ``_ensure_retriever_persistable`` when the
            retriever does not support persistence.
    """
    # Validate first so an unsupported retriever fails before anything is written.
    self._ensure_retriever_persistable()

    self._persist(persist_dir, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _from_index(
|
||||
cls,
|
||||
|
|
@ -200,15 +208,31 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
)
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
if not isinstance(self.retriever, ModifiableRAGRetriever):
|
||||
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")
|
||||
self._ensure_retriever_of_type(ModifiableRAGRetriever)
|
||||
|
||||
def _ensure_retriever_persistable(self):
    """Raise ``TypeError`` unless the retriever qualifies as a ``PersistableRAGRetriever``.

    The actual check is delegated to ``_ensure_retriever_of_type``.
    """
    self._ensure_retriever_of_type(PersistableRAGRetriever)
|
||||
|
||||
def _ensure_retriever_of_type(self, required_type: type[BaseRetriever]):
    """Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.

    Two checks run in order:

    1. If the retriever is a ``SimpleHybridRetriever``, at least one entry of
       its ``retrievers`` must be an instance of ``required_type``.
    2. The retriever itself (hybrid or not) must be an instance of
       ``required_type``.

    Args:
        required_type: The class that the retriever is expected to be an instance of.

    Raises:
        TypeError: If either check fails.
    """
    retriever = self.retriever

    # A hybrid retriever is only useful for this capability if some component provides it.
    if isinstance(retriever, SimpleHybridRetriever) and not any(
        isinstance(component, required_type) for component in retriever.retrievers
    ):
        raise TypeError(
            f"Must have at least one retriever of type {required_type.__name__} in SimpleHybridRetriever"
        )

    # This check is unconditional, so the hybrid wrapper itself must also pass it.
    if not isinstance(retriever, required_type):
        raise TypeError(f"The retriever is not of type {required_type.__name__}: {type(retriever)}")
|
||||
|
||||
def _save_nodes(self, nodes: list[BaseNode]):
    """Add ``nodes`` to both the retriever and the index.

    Args:
        nodes: Nodes to register; the same list is passed to both calls.
    """
    # for search in memory
    self.retriever.add_nodes(nodes)

    # for persist
    self.index.insert_nodes(nodes)
|
||||
def _persist(self, persist_dir: str, **kwargs):
    """Delegate persistence to the retriever.

    Args:
        persist_dir: Target directory, passed to ``self.retriever.persist``.
        **kwargs: Extra keyword arguments forwarded unchanged.
    """
    self.retriever.persist(persist_dir, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _try_reconstruct_obj(nodes: list[NodeWithScore]):
|
||||
|
|
@ -216,7 +240,8 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
for node in nodes:
|
||||
if node.metadata.get("is_obj", False):
|
||||
obj_cls = import_class(node.metadata["obj_cls_name"], node.metadata["obj_mod_name"])
|
||||
node.metadata["obj"] = obj_cls(**node.metadata["obj_dict"])
|
||||
obj_dict = json.loads(node.metadata["obj_json"])
|
||||
node.metadata["obj"] = obj_cls(**obj_dict)
|
||||
|
||||
@staticmethod
|
||||
def _fix_document_metadata(documents: list[Document]):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""RAG Interfaces."""
|
||||
|
||||
from typing import Any, Protocol
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class RAGObject(Protocol):
|
||||
|
|
@ -9,8 +9,8 @@ class RAGObject(Protocol):
|
|||
def rag_key(self) -> str:
|
||||
"""For rag search."""
|
||||
|
||||
def model_dump(self) -> dict[str, Any]:
|
||||
def model_dump_json(self) -> str:
|
||||
"""For rag persist.
|
||||
|
||||
Pydantic Model don't need to implement this, as there is a built-in function named model_dump.
|
||||
Pydantic Model don't need to implement this, as there is a built-in function named model_dump_json.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -31,3 +31,17 @@ class ModifiableRAGRetriever(RAGRetriever):
|
|||
@abstractmethod
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""To support add docs, must inplement this func"""
|
||||
|
||||
|
||||
class PersistableRAGRetriever(RAGRetriever):
    """Support persistent.

    Interface for retrievers that can write their data to a directory via
    ``persist``.
    """

    @classmethod
    def __subclasshook__(cls, C):
        # Only answer for this exact class; for subclasses, defer to the
        # default subclass check by returning NotImplemented.
        if cls is PersistableRAGRetriever:
            # NOTE(review): presumably check_methods reports whether C defines
            # "persist" -- confirm against the helper's definition.
            return check_methods(C, "persist")
        return NotImplemented

    @abstractmethod
    def persist(self, persist_dir: str, **kwargs) -> None:
        """To support persist, must implement this func"""
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from rank_bm25 import BM25Okapi
|
|||
class DynamicBM25Retriever(BM25Retriever):
|
||||
"""BM25 retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs):
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes"""
|
||||
self._nodes.extend(nodes)
|
||||
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
|
||||
|
|
|
|||
|
|
@ -7,6 +7,11 @@ from llama_index.core.schema import BaseNode
|
|||
class ChromaRetriever(VectorIndexRetriever):
|
||||
"""Chroma retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs):
|
||||
"""Support add nodes"""
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes."""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
    """Support persist.

    Chromadb automatically saves, so there is no need to implement.

    Args:
        persist_dir: Unused; accepted so the signature matches other retrievers.
        **kwargs: Unused.
    """
|
||||
|
|
|
|||
|
|
@ -7,6 +7,10 @@ from llama_index.core.schema import BaseNode
|
|||
class FAISSRetriever(VectorIndexRetriever):
|
||||
"""FAISS retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs):
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes"""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
    """Support persist.

    Writes the index's storage context to ``persist_dir``.

    Args:
        persist_dir: Directory handed to ``storage_context.persist``.
        **kwargs: Accepted for signature compatibility; not forwarded.
    """
    self._index.storage_context.persist(persist_dir)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,12 @@ class SimpleHybridRetriever(RAGRetriever):
|
|||
node_ids.add(n.node.node_id)
|
||||
return result
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode]):
|
||||
"""Support add nodes"""
|
||||
def add_nodes(self, nodes: list[BaseNode]) -> None:
|
||||
"""Support add nodes."""
|
||||
for r in self.retrievers:
|
||||
r.add_nodes(nodes)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
    """Support persist.

    Calls ``persist`` on every component retriever that provides it.
    Components without a callable ``persist`` attribute are skipped instead of
    raising ``AttributeError``, so a hybrid that mixes persistable and
    non-persistable retrievers can still be saved.

    Args:
        persist_dir: Directory passed to each component's ``persist``.
        **kwargs: Extra keyword arguments forwarded unchanged to each component.
    """
    for retriever in self.retrievers:
        persist_fn = getattr(retriever, "persist", None)
        if callable(persist_fn):
            persist_fn(persist_dir, **kwargs)
|
||||
|
|
|
|||
|
|
@ -93,8 +93,10 @@ class ObjectNodeMetadata(BaseModel):
|
|||
"""Metadata of ObjectNode."""
|
||||
|
||||
is_obj: bool = Field(default=True)
|
||||
obj: Any = Field(default=None, description="When retrieve, will reconstruct obj from obj_dict")
|
||||
obj_dict: dict = Field(..., description="Inplement rag.interface.RAGObject.model_dump(), e.g. obj.model_dump()")
|
||||
obj: Any = Field(default=None, description="When retrieve, will reconstruct obj from obj_json")
|
||||
obj_json: str = Field(
|
||||
..., description="Inplement rag.interface.RAGObject.model_dump_json(), e.g. obj.model_dump_json()"
|
||||
)
|
||||
obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__")
|
||||
obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__")
|
||||
|
||||
|
|
@ -110,6 +112,7 @@ class ObjectNode(TextNode):
|
|||
@staticmethod
|
||||
def get_obj_metadata(obj: RAGObject) -> dict:
|
||||
metadata = ObjectNodeMetadata(
|
||||
obj_dict=obj.model_dump(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
|
||||
obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
|
||||
)
|
||||
|
||||
return metadata.model_dump()
|
||||
|
|
|
|||
|
|
@ -150,8 +150,8 @@ class TestSimpleEngine:
|
|||
def rag_key(self):
|
||||
return ""
|
||||
|
||||
def model_dump(self):
|
||||
return {}
|
||||
def model_dump_json(self):
|
||||
return ""
|
||||
|
||||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue