diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 2d50b052f..f6b38f9e7 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -47,7 +47,7 @@ class ExperienceManager(BaseModel): exp (Experience): The experience to add. """ - if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_write: + if not self._is_writable(): return self.storage.add_objs([exp]) @@ -66,7 +66,7 @@ class ExperienceManager(BaseModel): list[Experience]: A list of experiences that match the args. """ - if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_read: + if not self._is_readable(): return [] nodes = await self.storage.aretrieve(req) @@ -86,6 +86,21 @@ class ExperienceManager(BaseModel): return self.storage.count() + @handle_exception + def delete_all_exps(self): + """Delete all experiences; does nothing when the experience pool is not writable.""" + + if not self._is_writable(): + return + + self.storage.clear(persist_dir=self.config.exp_pool.persist_path) + + def _is_readable(self) -> bool: + return self.config.exp_pool.enabled and self.config.exp_pool.enable_read + + def _is_writable(self) -> bool: + return self.config.exp_pool.enabled and self.config.exp_pool.enable_write + def _resolve_storage(self) -> "SimpleEngine": """Selects the appropriate storage creation method based on the configured retrieval type.""" diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index be4c3daf5..e48decdab 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -38,6 +38,7 @@ from metagpt.rag.factories import ( ) from metagpt.rag.interface import NoEmbedding, RAGObject from metagpt.rag.retrievers.base import ( + DeletableRAGRetriever, ModifiableRAGRetriever, PersistableRAGRetriever, QueryableRAGRetriever, @@ -218,7 +219,13 @@ class SimpleEngine(RetrieverQueryEngine): """Count.""" self._ensure_retriever_queryable() - return self._retriever.query_total_count() + return self.retriever.query_total_count() + + def clear(self, 
**kwargs): + """Clear the retriever after checking it is deletable; kwargs are forwarded to the retriever's clear().""" + self._ensure_retriever_deletable() + + return self.retriever.clear(**kwargs) @staticmethod def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]: @@ -277,6 +284,9 @@ class SimpleEngine(RetrieverQueryEngine): def _ensure_retriever_queryable(self): self._ensure_retriever_of_type(QueryableRAGRetriever) + def _ensure_retriever_deletable(self): + self._ensure_retriever_of_type(DeletableRAGRetriever) + def _ensure_retriever_of_type(self, required_type: BaseRetriever): """Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever. diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py index 5bd04adca..69475d6ea 100644 --- a/metagpt/rag/retrievers/base.py +++ b/metagpt/rag/retrievers/base.py @@ -58,4 +58,18 @@ class QueryableRAGRetriever(RAGRetriever): @abstractmethod def query_total_count(self) -> int: - """To support querying total count, must implement this func""" + """To support querying total count, must implement this func.""" + + +class DeletableRAGRetriever(RAGRetriever): + """Support deleting all nodes.""" + + @classmethod + def __subclasshook__(cls, C): + if cls is DeletableRAGRetriever: + return check_methods(C, "clear") + return NotImplemented + + @abstractmethod + def clear(self, **kwargs) -> int: + """To support deleting all nodes, must implement this func.""" diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index ace1bb86c..74cba5124 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -1,4 +1,5 @@ """BM25 retriever.""" +from pathlib import Path from typing import Callable, Optional from llama_index.core import VectorStoreIndex @@ -52,3 +53,18 @@ class DynamicBM25Retriever(BM25Retriever): """Support query total count.""" return len(self._nodes) + + def clear(self, **kwargs) -> None: + """Support deleting all nodes.""" + 
self._delete_json_files(kwargs.get("persist_dir")) + self._nodes = [] + + @staticmethod + def _delete_json_files(directory: str): + """Delete all JSON files in the specified directory.""" + + if not directory: + return + + for file in Path(directory).glob("*.json"): + file.unlink() diff --git a/metagpt/rag/retrievers/chroma_retriever.py b/metagpt/rag/retrievers/chroma_retriever.py index 6c466e49f..4d3d4469e 100644 --- a/metagpt/rag/retrievers/chroma_retriever.py +++ b/metagpt/rag/retrievers/chroma_retriever.py @@ -8,6 +8,10 @@ from llama_index.vector_stores.chroma import ChromaVectorStore class ChromaRetriever(VectorIndexRetriever): """Chroma retriever.""" + @property + def vector_store(self) -> ChromaVectorStore: + return self._vector_store + def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: """Support add nodes.""" self._index.insert_nodes(nodes, **kwargs) @@ -20,6 +24,11 @@ class ChromaRetriever(VectorIndexRetriever): def query_total_count(self) -> int: """Support query total count.""" - vector_store: ChromaVectorStore = self._vector_store + return self.vector_store._collection.count() - return vector_store._collection.count() + def clear(self, **kwargs) -> None: + """Support deleting all nodes.""" + + ids = self.vector_store._collection.get()["ids"] + if ids: + self.vector_store._collection.delete(ids=ids)