mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
log and load exp
This commit is contained in:
parent
5ff3e0de7d
commit
49505d37eb
5 changed files with 129 additions and 11 deletions
|
|
@ -46,8 +46,8 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None):
|
|||
metric=metric or Metric(score=Score(val=10, reason="Manual")),
|
||||
)
|
||||
exp_manager = get_exp_manager()
|
||||
exp_manager.config.exp_pool.enabled = True
|
||||
exp_manager.config.exp_pool.enable_write = True
|
||||
exp_manager.is_writable = True
|
||||
|
||||
exp_manager.create_exp(exp)
|
||||
logger.info(f"New experience created for the request `{req[:10]}`.")
|
||||
|
||||
|
|
|
|||
85
examples/exp_pool/load_exps_from_log.py
Normal file
85
examples/exp_pool/load_exps_from_log.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"""Load and save experiences from the log file."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.exp_pool import get_exp_manager
|
||||
from metagpt.exp_pool.schema import LOG_NEW_EXPERIENCE_PREFIX, Experience
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def load_exps(log_file_path: str) -> list[Experience]:
|
||||
"""Loads experiences from a log file.
|
||||
|
||||
Args:
|
||||
log_file_path (str): The path to the log file.
|
||||
|
||||
Returns:
|
||||
list[Experience]: A list of Experience objects loaded from the log file.
|
||||
"""
|
||||
|
||||
if not Path(log_file_path).exists():
|
||||
logger.warning(f"`load_exps` called with a non-existent log file path: {log_file_path}")
|
||||
return
|
||||
|
||||
exps = []
|
||||
with open(log_file_path, "r") as log_file:
|
||||
for line in log_file:
|
||||
if LOG_NEW_EXPERIENCE_PREFIX in line:
|
||||
json_str = line.split(LOG_NEW_EXPERIENCE_PREFIX, 1)[1].strip()
|
||||
exp_data = json.loads(json_str)
|
||||
|
||||
exp = Experience(**exp_data)
|
||||
exps.append(exp)
|
||||
|
||||
logger.info(f"Loaded {len(exps)} experiences from log file: {log_file_path}")
|
||||
|
||||
return exps
|
||||
|
||||
|
||||
def save_exps(exps: list[Experience]):
|
||||
"""Saves a list of experiences to the experience pool.
|
||||
|
||||
Args:
|
||||
exps (list[Experience]): The list of experiences to save.
|
||||
"""
|
||||
|
||||
if not exps:
|
||||
logger.warning("`save_exps` called with an empty list of experiences.")
|
||||
return
|
||||
|
||||
manager = get_exp_manager()
|
||||
manager.is_writable = True
|
||||
|
||||
manager.create_exps(exps)
|
||||
logger.info(f"Saved {len(exps)} experiences.")
|
||||
|
||||
|
||||
def get_log_file_path() -> str:
|
||||
"""Retrieves the path to the log file.
|
||||
|
||||
Returns:
|
||||
str: The path to the log file.
|
||||
|
||||
Raises:
|
||||
ValueError: If the log file path cannot be found.
|
||||
"""
|
||||
|
||||
handlers = logger._core.handlers
|
||||
|
||||
for handler in handlers.values():
|
||||
if "log" in handler._name:
|
||||
return handler._name[1:-1]
|
||||
|
||||
raise ValueError("Log file not found")
|
||||
|
||||
|
||||
def main():
|
||||
log_file_path = get_log_file_path()
|
||||
|
||||
exps = load_exps(log_file_path)
|
||||
save_exps(exps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -10,7 +10,13 @@ from metagpt.config2 import Config
|
|||
from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder
|
||||
from metagpt.exp_pool.manager import ExperienceManager, get_exp_manager
|
||||
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
|
||||
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
|
||||
from metagpt.exp_pool.schema import (
|
||||
LOG_NEW_EXPERIENCE_PREFIX,
|
||||
Experience,
|
||||
Metric,
|
||||
QueryType,
|
||||
Score,
|
||||
)
|
||||
from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer
|
||||
from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -173,6 +179,7 @@ class ExpCacheHandler(BaseModel):
|
|||
|
||||
exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score))
|
||||
self.exp_manager.create_exp(exp)
|
||||
self._log_exp(exp)
|
||||
|
||||
@staticmethod
|
||||
def choose_wrapper(func, wrapped_func):
|
||||
|
|
@ -215,3 +222,8 @@ class ExpCacheHandler(BaseModel):
|
|||
return await self.func(*self.args, **self.kwargs)
|
||||
|
||||
return self.func(*self.args, **self.kwargs)
|
||||
|
||||
def _log_exp(self, exp: Experience):
|
||||
log_entry = exp.model_dump_json(include={"uuid", "req", "resp", "tag"})
|
||||
|
||||
logger.debug(f"{LOG_NEW_EXPERIENCE_PREFIX}{log_entry}")
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class ExperienceManager(BaseModel):
|
|||
_storage: Any = None
|
||||
|
||||
@property
|
||||
def storage(self):
|
||||
def storage(self) -> "SimpleEngine":
|
||||
if self._storage is None:
|
||||
logger.info(f"exp_pool config: {self.config.exp_pool}")
|
||||
|
||||
|
|
@ -44,13 +44,21 @@ class ExperienceManager(BaseModel):
|
|||
self._storage = value
|
||||
|
||||
@property
|
||||
def _is_readable(self) -> bool:
|
||||
def is_readable(self) -> bool:
|
||||
return self.config.exp_pool.enabled and self.config.exp_pool.enable_read
|
||||
|
||||
@is_readable.setter
|
||||
def is_readable(self, value: bool):
|
||||
self.config.exp_pool.enabled = self.config.exp_pool.enable_read = value
|
||||
|
||||
@property
|
||||
def _is_writable(self) -> bool:
|
||||
def is_writable(self) -> bool:
|
||||
return self.config.exp_pool.enabled and self.config.exp_pool.enable_write
|
||||
|
||||
@is_writable.setter
|
||||
def is_writable(self, value: bool):
|
||||
self.config.exp_pool.enabled = self.config.exp_pool.enable_write = value
|
||||
|
||||
@handle_exception
|
||||
def create_exp(self, exp: Experience):
|
||||
"""Adds an experience to the storage if writing is enabled.
|
||||
|
|
@ -59,10 +67,19 @@ class ExperienceManager(BaseModel):
|
|||
exp (Experience): The experience to add.
|
||||
"""
|
||||
|
||||
if not self._is_writable:
|
||||
self.create_exps([exp])
|
||||
|
||||
@handle_exception
|
||||
def create_exps(self, exps: list[Experience]):
|
||||
"""Adds multiple experiences to the storage if writing is enabled.
|
||||
|
||||
Args:
|
||||
exps (list[Experience]): A list of experiences to add.
|
||||
"""
|
||||
if not self.is_writable:
|
||||
return
|
||||
|
||||
self.storage.add_objs([exp])
|
||||
self.storage.add_objs(exps)
|
||||
self.storage.persist(self.config.exp_pool.persist_path)
|
||||
|
||||
@handle_exception(default_return=[])
|
||||
|
|
@ -78,7 +95,7 @@ class ExperienceManager(BaseModel):
|
|||
list[Experience]: A list of experiences that match the args.
|
||||
"""
|
||||
|
||||
if not self._is_readable:
|
||||
if not self.is_readable:
|
||||
return []
|
||||
|
||||
nodes = await self.storage.aretrieve(req)
|
||||
|
|
@ -97,7 +114,7 @@ class ExperienceManager(BaseModel):
|
|||
def delete_all_exps(self):
|
||||
"""Delete the all experiences."""
|
||||
|
||||
if not self._is_writable:
|
||||
if not self.is_writable:
|
||||
return
|
||||
|
||||
self.storage.clear(persist_dir=self.config.exp_pool.persist_path)
|
||||
|
|
@ -210,7 +227,7 @@ class ExperienceManager(BaseModel):
|
|||
_exp_manager = None
|
||||
|
||||
|
||||
def get_exp_manager():
|
||||
def get_exp_manager() -> ExperienceManager:
|
||||
global _exp_manager
|
||||
if _exp_manager is None:
|
||||
_exp_manager = ExperienceManager()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -9,6 +10,8 @@ MAX_SCORE = 10
|
|||
|
||||
DEFAULT_SIMILARITY_TOP_K = 2
|
||||
|
||||
LOG_NEW_EXPERIENCE_PREFIX = "New experience: "
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
"""Type of query experiences."""
|
||||
|
|
@ -67,6 +70,7 @@ class Experience(BaseModel):
|
|||
tag: str = Field(default="", description="Tagging experience.")
|
||||
traj: Optional[Trajectory] = Field(default=None, description="Trajectory.")
|
||||
timestamp: Optional[float] = Field(default_factory=time.time)
|
||||
uuid: Optional[UUID] = Field(default_factory=uuid4)
|
||||
|
||||
def rag_key(self):
|
||||
return self.req
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue