log and load exp

This commit is contained in:
seehi 2024-08-27 20:02:48 +08:00
parent 5ff3e0de7d
commit 49505d37eb
5 changed files with 129 additions and 11 deletions

View file

@ -46,8 +46,8 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None):
metric=metric or Metric(score=Score(val=10, reason="Manual")),
)
exp_manager = get_exp_manager()
exp_manager.config.exp_pool.enabled = True
exp_manager.config.exp_pool.enable_write = True
exp_manager.is_writable = True
exp_manager.create_exp(exp)
logger.info(f"New experience created for the request `{req[:10]}`.")

View file

@ -0,0 +1,85 @@
"""Load and save experiences from the log file."""
import json
from pathlib import Path
from metagpt.exp_pool import get_exp_manager
from metagpt.exp_pool.schema import LOG_NEW_EXPERIENCE_PREFIX, Experience
from metagpt.logs import logger
def load_exps(log_file_path: str) -> list[Experience]:
"""Loads experiences from a log file.
Args:
log_file_path (str): The path to the log file.
Returns:
list[Experience]: A list of Experience objects loaded from the log file.
"""
if not Path(log_file_path).exists():
logger.warning(f"`load_exps` called with a non-existent log file path: {log_file_path}")
return
exps = []
with open(log_file_path, "r") as log_file:
for line in log_file:
if LOG_NEW_EXPERIENCE_PREFIX in line:
json_str = line.split(LOG_NEW_EXPERIENCE_PREFIX, 1)[1].strip()
exp_data = json.loads(json_str)
exp = Experience(**exp_data)
exps.append(exp)
logger.info(f"Loaded {len(exps)} experiences from log file: {log_file_path}")
return exps
def save_exps(exps: list[Experience]):
"""Saves a list of experiences to the experience pool.
Args:
exps (list[Experience]): The list of experiences to save.
"""
if not exps:
logger.warning("`save_exps` called with an empty list of experiences.")
return
manager = get_exp_manager()
manager.is_writable = True
manager.create_exps(exps)
logger.info(f"Saved {len(exps)} experiences.")
def get_log_file_path() -> str:
"""Retrieves the path to the log file.
Returns:
str: The path to the log file.
Raises:
ValueError: If the log file path cannot be found.
"""
handlers = logger._core.handlers
for handler in handlers.values():
if "log" in handler._name:
return handler._name[1:-1]
raise ValueError("Log file not found")
def main():
log_file_path = get_log_file_path()
exps = load_exps(log_file_path)
save_exps(exps)
if __name__ == "__main__":
main()

View file

@ -10,7 +10,13 @@ from metagpt.config2 import Config
from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder
from metagpt.exp_pool.manager import ExperienceManager, get_exp_manager
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
from metagpt.exp_pool.schema import (
LOG_NEW_EXPERIENCE_PREFIX,
Experience,
Metric,
QueryType,
Score,
)
from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer
from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer
from metagpt.logs import logger
@ -173,6 +179,7 @@ class ExpCacheHandler(BaseModel):
exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score))
self.exp_manager.create_exp(exp)
self._log_exp(exp)
@staticmethod
def choose_wrapper(func, wrapped_func):
@ -215,3 +222,8 @@ class ExpCacheHandler(BaseModel):
return await self.func(*self.args, **self.kwargs)
return self.func(*self.args, **self.kwargs)
def _log_exp(self, exp: Experience):
log_entry = exp.model_dump_json(include={"uuid", "req", "resp", "tag"})
logger.debug(f"{LOG_NEW_EXPERIENCE_PREFIX}{log_entry}")

View file

@ -31,7 +31,7 @@ class ExperienceManager(BaseModel):
_storage: Any = None
@property
def storage(self):
def storage(self) -> "SimpleEngine":
if self._storage is None:
logger.info(f"exp_pool config: {self.config.exp_pool}")
@ -44,13 +44,21 @@ class ExperienceManager(BaseModel):
self._storage = value
@property
def _is_readable(self) -> bool:
def is_readable(self) -> bool:
return self.config.exp_pool.enabled and self.config.exp_pool.enable_read
@is_readable.setter
def is_readable(self, value: bool):
self.config.exp_pool.enabled = self.config.exp_pool.enable_read = value
@property
def _is_writable(self) -> bool:
def is_writable(self) -> bool:
return self.config.exp_pool.enabled and self.config.exp_pool.enable_write
@is_writable.setter
def is_writable(self, value: bool):
self.config.exp_pool.enabled = self.config.exp_pool.enable_write = value
@handle_exception
def create_exp(self, exp: Experience):
"""Adds an experience to the storage if writing is enabled.
@ -59,10 +67,19 @@ class ExperienceManager(BaseModel):
exp (Experience): The experience to add.
"""
if not self._is_writable:
self.create_exps([exp])
@handle_exception
def create_exps(self, exps: list[Experience]):
"""Adds multiple experiences to the storage if writing is enabled.
Args:
exps (list[Experience]): A list of experiences to add.
"""
if not self.is_writable:
return
self.storage.add_objs([exp])
self.storage.add_objs(exps)
self.storage.persist(self.config.exp_pool.persist_path)
@handle_exception(default_return=[])
@ -78,7 +95,7 @@ class ExperienceManager(BaseModel):
list[Experience]: A list of experiences that match the args.
"""
if not self._is_readable:
if not self.is_readable:
return []
nodes = await self.storage.aretrieve(req)
@ -97,7 +114,7 @@ class ExperienceManager(BaseModel):
def delete_all_exps(self):
"""Delete the all experiences."""
if not self._is_writable:
if not self.is_writable:
return
self.storage.clear(persist_dir=self.config.exp_pool.persist_path)
@ -210,7 +227,7 @@ class ExperienceManager(BaseModel):
_exp_manager = None
def get_exp_manager():
def get_exp_manager() -> ExperienceManager:
global _exp_manager
if _exp_manager is None:
_exp_manager = ExperienceManager()

View file

@ -2,6 +2,7 @@
import time
from enum import Enum
from typing import Optional
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
@ -9,6 +10,8 @@ MAX_SCORE = 10
DEFAULT_SIMILARITY_TOP_K = 2
LOG_NEW_EXPERIENCE_PREFIX = "New experience: "
class QueryType(str, Enum):
"""Type of query experiences."""
@ -67,6 +70,7 @@ class Experience(BaseModel):
tag: str = Field(default="", description="Tagging experience.")
traj: Optional[Trajectory] = Field(default=None, description="Trajectory.")
timestamp: Optional[float] = Field(default_factory=time.time)
uuid: Optional[UUID] = Field(default_factory=uuid4)
def rag_key(self):
return self.req