From 9d327081baca2203f70c92fea9c3cef1742573ed Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 23 Aug 2024 10:14:12 +0800 Subject: [PATCH 1/7] add collection_name in exp_pool config --- config/config2.example.yaml | 2 ++ metagpt/configs/exp_pool_config.py | 1 + metagpt/exp_pool/manager.py | 9 ++------- metagpt/exp_pool/schema.py | 1 - metagpt/schema.py | 4 ++-- 5 files changed, 7 insertions(+), 10 deletions(-) diff --git a/config/config2.example.yaml b/config/config2.example.yaml index a7214f662..2a0ebcc47 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -81,6 +81,8 @@ exp_pool: persist_path: .chroma_exp_data # The directory. retrieval_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding. use_llm_ranker: true # Default is `true`, it will use LLM Reranker to get better result. + collection_name: experience_pool # When `retrieval_type` is `chroma`, `collection_name` is the collection name in chromadb. + azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py index 8d33b25aa..a4a2d5d41 100644 --- a/metagpt/configs/exp_pool_config.py +++ b/metagpt/configs/exp_pool_config.py @@ -22,3 +22,4 @@ class ExperiencePoolConfig(YamlModel): default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool." ) use_llm_ranker: bool = Field(default=True, description="Use LLM Reranker to get better result.") + collection_name: str = Field(default="experience_pool", description="The collection name in chromadb") diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index e38906d90..2d50b052f 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -7,12 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field from metagpt.config2 import Config from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType -from metagpt.exp_pool.schema import ( - DEFAULT_COLLECTION_NAME, - DEFAULT_SIMILARITY_TOP_K, - Experience, - QueryType, -) +from metagpt.exp_pool.schema import DEFAULT_SIMILARITY_TOP_K, Experience, QueryType from metagpt.logs import logger from metagpt.utils.exceptions import handle_exception @@ -166,7 +161,7 @@ class ExperienceManager(BaseModel): retriever_configs = [ ChromaRetrieverConfig( persist_path=self.config.exp_pool.persist_path, - collection_name=DEFAULT_COLLECTION_NAME, + collection_name=self.config.exp_pool.collection_name, similarity_top_k=DEFAULT_SIMILARITY_TOP_K, ) ] diff --git a/metagpt/exp_pool/schema.py b/metagpt/exp_pool/schema.py index b119e5850..a45910f0d 100644 --- a/metagpt/exp_pool/schema.py +++ b/metagpt/exp_pool/schema.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, Field MAX_SCORE = 10 -DEFAULT_COLLECTION_NAME = "experience_pool" DEFAULT_SIMILARITY_TOP_K = 2 diff --git a/metagpt/schema.py b/metagpt/schema.py index 8ef7dd0bb..5f9a5667f 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -148,7 +148,7 @@ class SerializationMixin(BaseModel, extra="forbid"): serialized_data = self.model_dump() write_json_file(file_path, serialized_data) - logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}") + logger.debug(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}") return file_path @@ -171,7 +171,7 @@ class SerializationMixin(BaseModel, extra="forbid"): data: dict = read_json_file(file_path) model = cls(**data) - logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}") + logger.debug(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}") return model From d3199604a258bbdd888a272905ee005b82a41c05 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 26 Aug 2024 19:35:00 +0800 Subject: [PATCH 2/7] exp_pool add delete_all_exps --- metagpt/exp_pool/manager.py | 19 +++++++++++++++++-- metagpt/rag/engines/simple.py | 12 +++++++++++- metagpt/rag/retrievers/base.py | 16 +++++++++++++++- metagpt/rag/retrievers/bm25_retriever.py | 16 ++++++++++++++++ metagpt/rag/retrievers/chroma_retriever.py | 13 +++++++++++-- 5 files changed, 70 insertions(+), 6 deletions(-) diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 2d50b052f..f6b38f9e7 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -47,7 +47,7 @@ class ExperienceManager(BaseModel): exp (Experience): The experience to add. """ - if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_write: + if not self._is_writable(): return self.storage.add_objs([exp]) @@ -66,7 +66,7 @@ class ExperienceManager(BaseModel): list[Experience]: A list of experiences that match the args. """ - if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_read: + if not self._is_readable(): return [] nodes = await self.storage.aretrieve(req) @@ -86,6 +86,21 @@ class ExperienceManager(BaseModel): return self.storage.count() + @handle_exception + def delete_all_exps(self): + """Delete the all experiences.""" + + if not self._is_writable(): + return + + self.storage.clear(persist_dir=self.config.exp_pool.persist_path) + + def _is_readable(self) -> bool: + return self.config.exp_pool.enabled and self.config.exp_pool.enable_read + + def _is_writable(self) -> bool: + return self.config.exp_pool.enabled and self.config.exp_pool.enable_write + def _resolve_storage(self) -> "SimpleEngine": """Selects the appropriate storage creation method based on the configured retrieval type.""" diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index be4c3daf5..e48decdab 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -38,6 +38,7 @@ from metagpt.rag.factories import ( ) from metagpt.rag.interface import NoEmbedding, RAGObject from metagpt.rag.retrievers.base import ( + DeletableRAGRetriever, ModifiableRAGRetriever, PersistableRAGRetriever, QueryableRAGRetriever, @@ -218,7 +219,13 @@ class SimpleEngine(RetrieverQueryEngine): """Count.""" self._ensure_retriever_queryable() - return self._retriever.query_total_count() + return self.retriever.query_total_count() + + def clear(self, **kwargs): + """Clear.""" + self._ensure_retriever_deletable() + + return self.retriever.clear(**kwargs) @staticmethod def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]: @@ -277,6 +284,9 @@ class SimpleEngine(RetrieverQueryEngine): def _ensure_retriever_queryable(self): self._ensure_retriever_of_type(QueryableRAGRetriever) + def _ensure_retriever_deletable(self): + self._ensure_retriever_of_type(DeletableRAGRetriever) + def _ensure_retriever_of_type(self, required_type: BaseRetriever): """Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever. diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py index 5bd04adca..69475d6ea 100644 --- a/metagpt/rag/retrievers/base.py +++ b/metagpt/rag/retrievers/base.py @@ -58,4 +58,18 @@ class QueryableRAGRetriever(RAGRetriever): @abstractmethod def query_total_count(self) -> int: - """To support querying total count, must implement this func""" + """To support querying total count, must implement this func.""" + + +class DeletableRAGRetriever(RAGRetriever): + """Support deleting all nodes.""" + + @classmethod + def __subclasshook__(cls, C): + if cls is DeletableRAGRetriever: + return check_methods(C, "clear") + return NotImplemented + + @abstractmethod + def clear(self, **kwargs) -> int: + """To support deleting all nodes, must implement this func.""" diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index ace1bb86c..74cba5124 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -1,4 +1,5 @@ """BM25 retriever.""" +from pathlib import Path from typing import Callable, Optional from llama_index.core import VectorStoreIndex @@ -52,3 +53,18 @@ class DynamicBM25Retriever(BM25Retriever): """Support query total count.""" return len(self._nodes) + + def clear(self, **kwargs) -> None: + """Support deleting all nodes.""" + self._delete_json_files(kwargs.get("persist_dir")) + self._nodes = [] + + @staticmethod + def _delete_json_files(directory: str): + """Delete all JSON files in the specified directory.""" + + if not directory: + return + + for file in Path(directory).glob("*.json"): + file.unlink() diff --git a/metagpt/rag/retrievers/chroma_retriever.py b/metagpt/rag/retrievers/chroma_retriever.py index 6c466e49f..4d3d4469e 100644 --- a/metagpt/rag/retrievers/chroma_retriever.py +++ b/metagpt/rag/retrievers/chroma_retriever.py @@ -8,6 +8,10 @@ from llama_index.vector_stores.chroma import ChromaVectorStore class ChromaRetriever(VectorIndexRetriever): """Chroma retriever.""" + @property + def vector_store(self) -> ChromaVectorStore: + return self._vector_store + def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: """Support add nodes.""" self._index.insert_nodes(nodes, **kwargs) @@ -20,6 +24,11 @@ class ChromaRetriever(VectorIndexRetriever): def query_total_count(self) -> int: """Support query total count.""" - vector_store: ChromaVectorStore = self._vector_store + return self.vector_store._collection.count() - return vector_store._collection.count() + def clear(self, **kwargs) -> None: + """Support deleting all nodes.""" + + ids = self.vector_store._collection.get()["ids"] + if ids: + self.vector_store._collection.delete(ids=ids) From 4846f60d6011206c49dc460524b21d6abbdfb174 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 27 Aug 2024 10:37:34 +0800 Subject: [PATCH 3/7] exp_pool storage --- metagpt/exp_pool/manager.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index f6b38f9e7..5fbac4013 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -39,6 +39,18 @@ class ExperienceManager(BaseModel): return self._storage + @storage.setter + def storage(self, value): + self._storage = value + + @property + def _is_readable(self) -> bool: + return self.config.exp_pool.enabled and self.config.exp_pool.enable_read + + @property + def _is_writable(self) -> bool: + return self.config.exp_pool.enabled and self.config.exp_pool.enable_write + @handle_exception def create_exp(self, exp: Experience): """Adds an experience to the storage if writing is enabled. @@ -47,7 +59,7 @@ class ExperienceManager(BaseModel): exp (Experience): The experience to add. """ - if not self._is_writable(): + if not self._is_writable: return self.storage.add_objs([exp]) @@ -66,7 +78,7 @@ class ExperienceManager(BaseModel): list[Experience]: A list of experiences that match the args. """ - if not self._is_readable(): + if not self._is_readable: return [] nodes = await self.storage.aretrieve(req) @@ -81,25 +93,19 @@ class ExperienceManager(BaseModel): return exps - def get_exps_count(self) -> int: - """Get the total number of experiences.""" - - return self.storage.count() - @handle_exception def delete_all_exps(self): """Delete the all experiences.""" - if not self._is_writable(): + if not self._is_writable: return self.storage.clear(persist_dir=self.config.exp_pool.persist_path) - def _is_readable(self) -> bool: - return self.config.exp_pool.enabled and self.config.exp_pool.enable_read + def get_exps_count(self) -> int: + """Get the total number of experiences.""" - def _is_writable(self) -> bool: - return self.config.exp_pool.enabled and self.config.exp_pool.enable_write + return self.storage.count() def _resolve_storage(self) -> "SimpleEngine": """Selects the appropriate storage creation method based on the configured retrieval type.""" From 5ff3e0de7d1b652c7a6d0a3572ab7ac68fa639f1 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 27 Aug 2024 15:57:11 +0800 Subject: [PATCH 4/7] update comment --- metagpt/schema.py | 2 +- metagpt/utils/common.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/metagpt/schema.py b/metagpt/schema.py index 5f9a5667f..201ff4357 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -147,7 +147,7 @@ class SerializationMixin(BaseModel, extra="forbid"): serialized_data = self.model_dump() - write_json_file(file_path, serialized_data) + write_json_file(file_path, serialized_data, use_fallback=True) logger.debug(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}") return file_path diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 2b2a209be..42a872c76 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -581,20 +581,20 @@ def read_json_file(json_file: str, encoding: str = "utf-8") -> list[Any]: def handle_unknown_serialization(x: Any) -> str: - """For `to_jsonable_python` debug, unknown values will be logged instead of raising an exception.""" + """For `to_jsonable_python` debug, get more detail about the x.""" if inspect.ismethod(x): - logger.error(f"Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}") + tip = f"Cannot serialize method '{x.__func__.__name__}' of class '{x.__self__.__class__.__name__}'" elif inspect.isfunction(x): - logger.error(f"Function: {x.__name__}") + tip = f"Cannot serialize function '{x.__name__}'" elif hasattr(x, "__class__"): - logger.error(f"Instance of: {x.__class__.__name__}") + tip = f"Cannot serialize instance of '{x.__class__.__name__}'" elif hasattr(x, "__name__"): - logger.error(f"Class or module: {x.__name__}") + tip = f"Cannot serialize class or module '{x.__name__}'" else: - logger.error(f"Unknown type: {type(x)}") + tip = f"Cannot serialize object of type '{type(x).__name__}'" - return f"" + raise TypeError(tip) def write_json_file(json_file: str, data: Any, encoding: str = "utf-8", indent: int = 4, use_fallback: bool = False): From 49505d37eb01cdd7e7935fbaec886f03c51b8c28 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 27 Aug 2024 20:02:48 +0800 Subject: [PATCH 5/7] log and load exp --- examples/exp_pool/init_exp_pool.py | 4 +- examples/exp_pool/load_exps_from_log.py | 85 +++++++++++++++++++++++++ metagpt/exp_pool/decorator.py | 14 +++- metagpt/exp_pool/manager.py | 33 +++++++--- metagpt/exp_pool/schema.py | 4 ++ 5 files changed, 129 insertions(+), 11 deletions(-) create mode 100644 examples/exp_pool/load_exps_from_log.py diff --git a/examples/exp_pool/init_exp_pool.py b/examples/exp_pool/init_exp_pool.py index 62747b8d8..c7412af22 100644 --- a/examples/exp_pool/init_exp_pool.py +++ b/examples/exp_pool/init_exp_pool.py @@ -46,8 +46,8 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None): metric=metric or Metric(score=Score(val=10, reason="Manual")), ) exp_manager = get_exp_manager() - exp_manager.config.exp_pool.enabled = True - exp_manager.config.exp_pool.enable_write = True + exp_manager.is_writable = True + exp_manager.create_exp(exp) logger.info(f"New experience created for the request `{req[:10]}`.") diff --git a/examples/exp_pool/load_exps_from_log.py b/examples/exp_pool/load_exps_from_log.py new file mode 100644 index 000000000..77eeff6dd --- /dev/null +++ b/examples/exp_pool/load_exps_from_log.py @@ -0,0 +1,85 @@ +"""Load and save experiences from the log file.""" + +import json +from pathlib import Path + +from metagpt.exp_pool import get_exp_manager +from metagpt.exp_pool.schema import LOG_NEW_EXPERIENCE_PREFIX, Experience +from metagpt.logs import logger + + +def load_exps(log_file_path: str) -> list[Experience]: + """Loads experiences from a log file. + + Args: + log_file_path (str): The path to the log file. + + Returns: + list[Experience]: A list of Experience objects loaded from the log file. + """ + + if not Path(log_file_path).exists(): + logger.warning(f"`load_exps` called with a non-existent log file path: {log_file_path}") + return + + exps = [] + with open(log_file_path, "r") as log_file: + for line in log_file: + if LOG_NEW_EXPERIENCE_PREFIX in line: + json_str = line.split(LOG_NEW_EXPERIENCE_PREFIX, 1)[1].strip() + exp_data = json.loads(json_str) + + exp = Experience(**exp_data) + exps.append(exp) + + logger.info(f"Loaded {len(exps)} experiences from log file: {log_file_path}") + + return exps + + +def save_exps(exps: list[Experience]): + """Saves a list of experiences to the experience pool. + + Args: + exps (list[Experience]): The list of experiences to save. + """ + + if not exps: + logger.warning("`save_exps` called with an empty list of experiences.") + return + + manager = get_exp_manager() + manager.is_writable = True + + manager.create_exps(exps) + logger.info(f"Saved {len(exps)} experiences.") + + +def get_log_file_path() -> str: + """Retrieves the path to the log file. + + Returns: + str: The path to the log file. + + Raises: + ValueError: If the log file path cannot be found. + """ + + handlers = logger._core.handlers + + for handler in handlers.values(): + if "log" in handler._name: + return handler._name[1:-1] + + raise ValueError("Log file not found") + + +def main(): + log_file_path = get_log_file_path() + + exps = load_exps(log_file_path) + save_exps(exps) + + +if __name__ == "__main__": + main() diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index 9b2cf3474..d49c13e95 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -10,7 +10,13 @@ from metagpt.config2 import Config from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder from metagpt.exp_pool.manager import ExperienceManager, get_exp_manager from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge -from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score +from metagpt.exp_pool.schema import ( + LOG_NEW_EXPERIENCE_PREFIX, + Experience, + Metric, + QueryType, + Score, +) from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer from metagpt.logs import logger @@ -173,6 +179,7 @@ class ExpCacheHandler(BaseModel): exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score)) self.exp_manager.create_exp(exp) + self._log_exp(exp) @staticmethod def choose_wrapper(func, wrapped_func): @@ -215,3 +222,8 @@ class ExpCacheHandler(BaseModel): return await self.func(*self.args, **self.kwargs) return self.func(*self.args, **self.kwargs) + + def _log_exp(self, exp: Experience): + log_entry = exp.model_dump_json(include={"uuid", "req", "resp", "tag"}) + + logger.debug(f"{LOG_NEW_EXPERIENCE_PREFIX}{log_entry}") diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 5fbac4013..38772239b 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -31,7 +31,7 @@ class ExperienceManager(BaseModel): _storage: Any = None @property - def storage(self): + def storage(self) -> "SimpleEngine": if self._storage is None: logger.info(f"exp_pool config: {self.config.exp_pool}") @@ -44,13 +44,21 @@ class ExperienceManager(BaseModel): self._storage = value @property - def _is_readable(self) -> bool: + def is_readable(self) -> bool: return self.config.exp_pool.enabled and self.config.exp_pool.enable_read + @is_readable.setter + def is_readable(self, value: bool): + self.config.exp_pool.enabled = self.config.exp_pool.enable_read = value + @property - def _is_writable(self) -> bool: + def is_writable(self) -> bool: return self.config.exp_pool.enabled and self.config.exp_pool.enable_write + @is_writable.setter + def is_writable(self, value: bool): + self.config.exp_pool.enabled = self.config.exp_pool.enable_write = value + @handle_exception def create_exp(self, exp: Experience): """Adds an experience to the storage if writing is enabled. @@ -59,10 +67,19 @@ class ExperienceManager(BaseModel): exp (Experience): The experience to add. """ - if not self._is_writable: + self.create_exps([exp]) + + @handle_exception + def create_exps(self, exps: list[Experience]): + """Adds multiple experiences to the storage if writing is enabled. + + Args: + exps (list[Experience]): A list of experiences to add. + """ + if not self.is_writable: return - self.storage.add_objs([exp]) + self.storage.add_objs(exps) self.storage.persist(self.config.exp_pool.persist_path) @handle_exception(default_return=[]) @@ -78,7 +95,7 @@ class ExperienceManager(BaseModel): list[Experience]: A list of experiences that match the args. """ - if not self._is_readable: + if not self.is_readable: return [] nodes = await self.storage.aretrieve(req) @@ -97,7 +114,7 @@ class ExperienceManager(BaseModel): def delete_all_exps(self): """Delete the all experiences.""" - if not self._is_writable: + if not self.is_writable: return self.storage.clear(persist_dir=self.config.exp_pool.persist_path) @@ -210,7 +227,7 @@ class ExperienceManager(BaseModel): _exp_manager = None -def get_exp_manager(): +def get_exp_manager() -> ExperienceManager: global _exp_manager if _exp_manager is None: _exp_manager = ExperienceManager() diff --git a/metagpt/exp_pool/schema.py b/metagpt/exp_pool/schema.py index a45910f0d..fea48a7f7 100644 --- a/metagpt/exp_pool/schema.py +++ b/metagpt/exp_pool/schema.py @@ -2,6 +2,7 @@ import time from enum import Enum from typing import Optional +from uuid import UUID, uuid4 from pydantic import BaseModel, Field @@ -9,6 +10,8 @@ MAX_SCORE = 10 DEFAULT_SIMILARITY_TOP_K = 2 +LOG_NEW_EXPERIENCE_PREFIX = "New experience: " + class QueryType(str, Enum): """Type of query experiences.""" @@ -67,6 +70,7 @@ class Experience(BaseModel): tag: str = Field(default="", description="Tagging experience.") traj: Optional[Trajectory] = Field(default=None, description="Trajectory.") timestamp: Optional[float] = Field(default_factory=time.time) + uuid: Optional[UUID] = Field(default_factory=uuid4) def rag_key(self): return self.req From 37c4b5c1587419e4178cec80ae83e50b8df48693 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 28 Aug 2024 10:30:51 +0800 Subject: [PATCH 6/7] update comment --- metagpt/rag/retrievers/bm25_retriever.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index 74cba5124..4891fad50 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -37,6 +37,7 @@ class DynamicBM25Retriever(BM25Retriever): def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: """Support add nodes.""" + self._nodes.extend(nodes) self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] self.bm25 = BM25Okapi(self._corpus) @@ -46,6 +47,7 @@ class DynamicBM25Retriever(BM25Retriever): def persist(self, persist_dir: str, **kwargs) -> None: """Support persist.""" + if self._index: self._index.storage_context.persist(persist_dir) @@ -56,6 +58,7 @@ class DynamicBM25Retriever(BM25Retriever): def clear(self, **kwargs) -> None: """Support deleting all nodes.""" + self._delete_json_files(kwargs.get("persist_dir")) self._nodes = [] From dbf41120b0999aafdfec32cd4617bbdd3290874a Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 28 Aug 2024 11:46:05 +0800 Subject: [PATCH 7/7] update setter of is_writable and is_readable --- metagpt/exp_pool/manager.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 38772239b..35de17079 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -49,7 +49,11 @@ class ExperienceManager(BaseModel): @is_readable.setter def is_readable(self, value: bool): - self.config.exp_pool.enabled = self.config.exp_pool.enable_read = value + self.config.exp_pool.enable_read = value + + # If set to True, ensure that enabled is also True. + if value: + self.config.exp_pool.enabled = True @property def is_writable(self) -> bool: @@ -57,7 +61,11 @@ class ExperienceManager(BaseModel): @is_writable.setter def is_writable(self, value: bool): - self.config.exp_pool.enabled = self.config.exp_pool.enable_write = value + self.config.exp_pool.enable_write = value + + # If set to True, ensure that enabled is also True. + if value: + self.config.exp_pool.enabled = True @handle_exception def create_exp(self, exp: Experience):