mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
exp_pool add delete_all_exps
This commit is contained in:
parent
9d327081ba
commit
d3199604a2
5 changed files with 70 additions and 6 deletions
|
|
@ -47,7 +47,7 @@ class ExperienceManager(BaseModel):
|
|||
exp (Experience): The experience to add.
|
||||
"""
|
||||
|
||||
if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_write:
|
||||
if not self._is_writable():
|
||||
return
|
||||
|
||||
self.storage.add_objs([exp])
|
||||
|
|
@ -66,7 +66,7 @@ class ExperienceManager(BaseModel):
|
|||
list[Experience]: A list of experiences that match the args.
|
||||
"""
|
||||
|
||||
if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_read:
|
||||
if not self._is_readable():
|
||||
return []
|
||||
|
||||
nodes = await self.storage.aretrieve(req)
|
||||
|
|
@ -86,6 +86,21 @@ class ExperienceManager(BaseModel):
|
|||
|
||||
return self.storage.count()
|
||||
|
||||
@handle_exception
|
||||
def delete_all_exps(self):
|
||||
"""Delete the all experiences."""
|
||||
|
||||
if not self._is_writable():
|
||||
return
|
||||
|
||||
self.storage.clear(persist_dir=self.config.exp_pool.persist_path)
|
||||
|
||||
def _is_readable(self) -> bool:
|
||||
return self.config.exp_pool.enabled and self.config.exp_pool.enable_read
|
||||
|
||||
def _is_writable(self) -> bool:
|
||||
return self.config.exp_pool.enabled and self.config.exp_pool.enable_write
|
||||
|
||||
def _resolve_storage(self) -> "SimpleEngine":
|
||||
"""Selects the appropriate storage creation method based on the configured retrieval type."""
|
||||
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ from metagpt.rag.factories import (
|
|||
)
|
||||
from metagpt.rag.interface import NoEmbedding, RAGObject
|
||||
from metagpt.rag.retrievers.base import (
|
||||
DeletableRAGRetriever,
|
||||
ModifiableRAGRetriever,
|
||||
PersistableRAGRetriever,
|
||||
QueryableRAGRetriever,
|
||||
|
|
@ -218,7 +219,13 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Count."""
|
||||
self._ensure_retriever_queryable()
|
||||
|
||||
return self._retriever.query_total_count()
|
||||
return self.retriever.query_total_count()
|
||||
|
||||
def clear(self, **kwargs):
|
||||
"""Clear."""
|
||||
self._ensure_retriever_deletable()
|
||||
|
||||
return self.retriever.clear(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
|
||||
|
|
@ -277,6 +284,9 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
def _ensure_retriever_queryable(self):
|
||||
self._ensure_retriever_of_type(QueryableRAGRetriever)
|
||||
|
||||
def _ensure_retriever_deletable(self):
|
||||
self._ensure_retriever_of_type(DeletableRAGRetriever)
|
||||
|
||||
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
|
||||
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.
|
||||
|
||||
|
|
|
|||
|
|
@ -58,4 +58,18 @@ class QueryableRAGRetriever(RAGRetriever):
|
|||
|
||||
@abstractmethod
|
||||
def query_total_count(self) -> int:
|
||||
"""To support querying total count, must implement this func"""
|
||||
"""To support querying total count, must implement this func."""
|
||||
|
||||
|
||||
class DeletableRAGRetriever(RAGRetriever):
|
||||
"""Support deleting all nodes."""
|
||||
|
||||
@classmethod
|
||||
def __subclasshook__(cls, C):
|
||||
if cls is DeletableRAGRetriever:
|
||||
return check_methods(C, "clear")
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, **kwargs) -> int:
|
||||
"""To support deleting all nodes, must implement this func."""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""BM25 retriever."""
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
|
@ -52,3 +53,18 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
"""Support query total count."""
|
||||
|
||||
return len(self._nodes)
|
||||
|
||||
def clear(self, **kwargs) -> None:
|
||||
"""Support deleting all nodes."""
|
||||
self._delete_json_files(kwargs.get("persist_dir"))
|
||||
self._nodes = []
|
||||
|
||||
@staticmethod
|
||||
def _delete_json_files(directory: str):
|
||||
"""Delete all JSON files in the specified directory."""
|
||||
|
||||
if not directory:
|
||||
return
|
||||
|
||||
for file in Path(directory).glob("*.json"):
|
||||
file.unlink()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,10 @@ from llama_index.vector_stores.chroma import ChromaVectorStore
|
|||
class ChromaRetriever(VectorIndexRetriever):
|
||||
"""Chroma retriever."""
|
||||
|
||||
@property
|
||||
def vector_store(self) -> ChromaVectorStore:
|
||||
return self._vector_store
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes."""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
|
@ -20,6 +24,11 @@ class ChromaRetriever(VectorIndexRetriever):
|
|||
def query_total_count(self) -> int:
|
||||
"""Support query total count."""
|
||||
|
||||
vector_store: ChromaVectorStore = self._vector_store
|
||||
return self.vector_store._collection.count()
|
||||
|
||||
return vector_store._collection.count()
|
||||
def clear(self, **kwargs) -> None:
|
||||
"""Support deleting all nodes."""
|
||||
|
||||
ids = self.vector_store._collection.get()["ids"]
|
||||
if ids:
|
||||
self.vector_store._collection.delete(ids=ids)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue