diff --git a/examples/exp_pool/init_exp_pool.py b/examples/exp_pool/init_exp_pool.py index 62747b8d8..c7412af22 100644 --- a/examples/exp_pool/init_exp_pool.py +++ b/examples/exp_pool/init_exp_pool.py @@ -46,8 +46,8 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None): metric=metric or Metric(score=Score(val=10, reason="Manual")), ) exp_manager = get_exp_manager() - exp_manager.config.exp_pool.enabled = True - exp_manager.config.exp_pool.enable_write = True + exp_manager.is_writable = True + exp_manager.create_exp(exp) logger.info(f"New experience created for the request `{req[:10]}`.") diff --git a/examples/exp_pool/load_exps_from_log.py b/examples/exp_pool/load_exps_from_log.py new file mode 100644 index 000000000..77eeff6dd --- /dev/null +++ b/examples/exp_pool/load_exps_from_log.py @@ -0,0 +1,85 @@ +"""Load and save experiences from the log file.""" + +import json +from pathlib import Path + +from metagpt.exp_pool import get_exp_manager +from metagpt.exp_pool.schema import LOG_NEW_EXPERIENCE_PREFIX, Experience +from metagpt.logs import logger + + +def load_exps(log_file_path: str) -> list[Experience]: + """Loads experiences from a log file. + + Args: + log_file_path (str): The path to the log file. + + Returns: + list[Experience]: A list of Experience objects loaded from the log file. + """ + + if not Path(log_file_path).exists(): + logger.warning(f"`load_exps` called with a non-existent log file path: {log_file_path}") + return + + exps = [] + with open(log_file_path, "r") as log_file: + for line in log_file: + if LOG_NEW_EXPERIENCE_PREFIX in line: + json_str = line.split(LOG_NEW_EXPERIENCE_PREFIX, 1)[1].strip() + exp_data = json.loads(json_str) + + exp = Experience(**exp_data) + exps.append(exp) + + logger.info(f"Loaded {len(exps)} experiences from log file: {log_file_path}") + + return exps + + +def save_exps(exps: list[Experience]): + """Saves a list of experiences to the experience pool. + + Args: + exps (list[Experience]): The list of experiences to save. + """ + + if not exps: + logger.warning("`save_exps` called with an empty list of experiences.") + return + + manager = get_exp_manager() + manager.is_writable = True + + manager.create_exps(exps) + logger.info(f"Saved {len(exps)} experiences.") + + +def get_log_file_path() -> str: + """Retrieves the path to the log file. + + Returns: + str: The path to the log file. + + Raises: + ValueError: If the log file path cannot be found. + """ + + handlers = logger._core.handlers + + for handler in handlers.values(): + if "log" in handler._name: + return handler._name[1:-1] + + raise ValueError("Log file not found") + + +def main(): + log_file_path = get_log_file_path() + + exps = load_exps(log_file_path) + save_exps(exps) + + +if __name__ == "__main__": + main() diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index 9b2cf3474..d49c13e95 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -10,7 +10,13 @@ from metagpt.config2 import Config from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder from metagpt.exp_pool.manager import ExperienceManager, get_exp_manager from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge -from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score +from metagpt.exp_pool.schema import ( + LOG_NEW_EXPERIENCE_PREFIX, + Experience, + Metric, + QueryType, + Score, +) from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer from metagpt.logs import logger @@ -173,6 +179,7 @@ class ExpCacheHandler(BaseModel): exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score)) self.exp_manager.create_exp(exp) + self._log_exp(exp) @staticmethod def choose_wrapper(func, wrapped_func): @@ -215,3 +222,8 @@ class ExpCacheHandler(BaseModel): return await self.func(*self.args, **self.kwargs) return self.func(*self.args, **self.kwargs) + + def _log_exp(self, exp: Experience): + log_entry = exp.model_dump_json(include={"uuid", "req", "resp", "tag"}) + + logger.debug(f"{LOG_NEW_EXPERIENCE_PREFIX}{log_entry}") diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 5fbac4013..38772239b 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -31,7 +31,7 @@ class ExperienceManager(BaseModel): _storage: Any = None @property - def storage(self): + def storage(self) -> "SimpleEngine": if self._storage is None: logger.info(f"exp_pool config: {self.config.exp_pool}") @@ -44,13 +44,21 @@ class ExperienceManager(BaseModel): self._storage = value @property - def _is_readable(self) -> bool: + def is_readable(self) -> bool: return self.config.exp_pool.enabled and self.config.exp_pool.enable_read + @is_readable.setter + def is_readable(self, value: bool): + self.config.exp_pool.enabled = self.config.exp_pool.enable_read = value + @property - def _is_writable(self) -> bool: + def is_writable(self) -> bool: return self.config.exp_pool.enabled and self.config.exp_pool.enable_write + @is_writable.setter + def is_writable(self, value: bool): + self.config.exp_pool.enabled = self.config.exp_pool.enable_write = value + @handle_exception def create_exp(self, exp: Experience): """Adds an experience to the storage if writing is enabled. @@ -59,10 +67,19 @@ class ExperienceManager(BaseModel): exp (Experience): The experience to add. """ - if not self._is_writable: + self.create_exps([exp]) + + @handle_exception + def create_exps(self, exps: list[Experience]): + """Adds multiple experiences to the storage if writing is enabled. + + Args: + exps (list[Experience]): A list of experiences to add. + """ + if not self.is_writable: return - self.storage.add_objs([exp]) + self.storage.add_objs(exps) self.storage.persist(self.config.exp_pool.persist_path) @handle_exception(default_return=[]) @@ -78,7 +95,7 @@ class ExperienceManager(BaseModel): list[Experience]: A list of experiences that match the args. """ - if not self._is_readable: + if not self.is_readable: return [] nodes = await self.storage.aretrieve(req) @@ -97,7 +114,7 @@ class ExperienceManager(BaseModel): def delete_all_exps(self): """Delete the all experiences.""" - if not self._is_writable: + if not self.is_writable: return self.storage.clear(persist_dir=self.config.exp_pool.persist_path) @@ -210,7 +227,7 @@ class ExperienceManager(BaseModel): _exp_manager = None -def get_exp_manager(): +def get_exp_manager() -> ExperienceManager: global _exp_manager if _exp_manager is None: _exp_manager = ExperienceManager() diff --git a/metagpt/exp_pool/schema.py b/metagpt/exp_pool/schema.py index a45910f0d..fea48a7f7 100644 --- a/metagpt/exp_pool/schema.py +++ b/metagpt/exp_pool/schema.py @@ -2,6 +2,7 @@ import time from enum import Enum from typing import Optional +from uuid import UUID, uuid4 from pydantic import BaseModel, Field @@ -9,6 +10,8 @@ MAX_SCORE = 10 DEFAULT_SIMILARITY_TOP_K = 2 +LOG_NEW_EXPERIENCE_PREFIX = "New experience: " + class QueryType(str, Enum): """Type of query experiences.""" @@ -67,6 +70,7 @@ class Experience(BaseModel): tag: str = Field(default="", description="Tagging experience.") traj: Optional[Trajectory] = Field(default=None, description="Trajectory.") timestamp: Optional[float] = Field(default_factory=time.time) + uuid: Optional[UUID] = Field(default_factory=uuid4) def rag_key(self): return self.req