diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py
index f7538ae2e..16ffe69db 100644
--- a/metagpt/roles/ml_engineer.py
+++ b/metagpt/roles/ml_engineer.py
@@ -3,8 +3,7 @@ import json
 from datetime import datetime
 
 import fire
-import nbformat
-from pathlib import Path
+
 
 from metagpt.actions import Action
 from metagpt.actions.debug_code import DebugCode
@@ -27,7 +26,7 @@ from metagpt.roles.kaggle_manager import DownloadData, SubmitResult
 from metagpt.schema import Message, Plan
 from metagpt.utils.common import remove_comments, create_func_config
 from metagpt.utils.save_code import save_code_file
-
+from metagpt.utils.recovery_util import save_history, load_history
 
 class UpdateDataColumns(Action):
     async def run(self, plan: Plan = None) -> dict:
@@ -297,49 +296,8 @@ if __name__ == "__main__":
     save_dir = ""
     # save_dir = DATA_PATH / "output" / "2023-12-14_20-40-34"
 
-    def load_history(save_dir: str = save_dir):
-        """
-        Load history from the specified save directory.
-
-        Args:
-            save_dir (str): The directory from which to load the history.
-
-        Returns:
-            Tuple: A tuple containing the loaded plan and notebook.
-        """
-
-        plan_path = Path(save_dir) / "plan.json"
-        nb_path = Path(save_dir) / "history_nb" / "code.ipynb"
-        plan = json.load(open(plan_path, "r", encoding="utf-8"))
-        nb = nbformat.read(open(nb_path, "r", encoding="utf-8"), as_version=nbformat.NO_CONVERT)
-        return plan, nb
 
 
-    async def save_history(role: Role = MLEngineer, save_dir: str = save_dir):
-        """
-        Save history to the specified directory.
-
-        Args:
-            role (Role): The role containing the plan and execute_code attributes.
-            save_dir (str): The directory to save the history.
-
-        Returns:
-            Path: The path to the saved history directory.
-        """
-        record_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
-        save_path = DATA_PATH / "output" / f"{record_time}"
-
-        # overwrite exist trajectory
-        save_path.mkdir(parents=True, exist_ok=True)
-
-        plan = role.plan.dict()
-
-        with open(save_path / "plan.json", "w", encoding="utf-8") as plan_file:
-            json.dump(plan, plan_file, indent=4, ensure_ascii=False)
-
-        save_code_file(name=Path(record_time) / "history_nb", code_context=role.execute_code.nb, file_format="ipynb")
-        return save_path
-
 
     async def main(requirement: str = requirement, auto_run: bool = True, save_dir: str = save_dir):
         """
@@ -368,7 +326,7 @@ if __name__ == "__main__":
 
             await role.run(requirement)
         except Exception as e:
-            save_path = await save_history(role, save_dir)
+            save_path = save_history(role, save_dir)
 
             logger.exception(f"An error occurred: {e}, save trajectory here: {save_path}")
 
diff --git a/metagpt/utils/recovery_util.py b/metagpt/utils/recovery_util.py
new file mode 100644
--- /dev/null
+++ b/metagpt/utils/recovery_util.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# @Date : 12/20/2023 11:07 AM
+# @Author : stellahong (stellahong@fuzhi.ai)
+# @Desc : Save and reload a role's plan and notebook history for recovery.
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import Optional
+
+import nbformat
+
+from metagpt.const import DATA_PATH
+from metagpt.roles.role import Role
+from metagpt.utils.save_code import save_code_file
+
+
+def load_history(save_dir: str = ""):
+    """
+    Load history from the specified save directory.
+
+    Args:
+        save_dir (str): The directory from which to load the history.
+
+    Returns:
+        Tuple: A tuple containing the loaded plan and notebook.
+    """
+    plan_path = Path(save_dir) / "plan.json"
+    nb_path = Path(save_dir) / "history_nb" / "code.ipynb"
+    # Context managers make sure both files are closed, even if parsing fails.
+    with open(plan_path, "r", encoding="utf-8") as plan_file:
+        plan = json.load(plan_file)
+    with open(nb_path, "r", encoding="utf-8") as nb_file:
+        nb = nbformat.read(nb_file, as_version=nbformat.NO_CONVERT)
+    return plan, nb
+
+
+def save_history(role: Optional[Role] = None, save_dir: str = ""):
+    """
+    Save history to a new timestamped directory under DATA_PATH / "output".
+
+    Args:
+        role (Role): The role instance containing the plan and execute_code attributes.
+        save_dir (str): Currently unused; the history always goes to a fresh
+            timestamped directory. Kept so existing callers keep working.
+
+    Returns:
+        Path: The path to the saved history directory.
+
+    Raises:
+        ValueError: If no role instance is given.
+    """
+    # A role instance is required. Defaulting to the MLEngineer class meant importing
+    # metagpt.roles.ml_engineer here, which imports this module back (import cycle).
+    if role is None:
+        raise ValueError("save_history requires a role instance")
+
+    record_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+    save_path = DATA_PATH / "output" / record_time
+
+    # Reuse the directory if it already exists.
+    save_path.mkdir(parents=True, exist_ok=True)
+
+    plan = role.plan.dict()
+
+    with open(save_path / "plan.json", "w", encoding="utf-8") as plan_file:
+        json.dump(plan, plan_file, indent=4, ensure_ascii=False)
+
+    # NOTE(review): presumably save_code_file resolves `name` under DATA_PATH / "output" - confirm.
+    save_code_file(name=Path(record_time) / "history_nb", code_context=role.execute_code.nb, file_format="ipynb")
+    return save_path