From 82ccdde687ff55734bfe16353d1511ea34c3f4ed Mon Sep 17 00:00:00 2001 From: stellahsr Date: Thu, 14 Dec 2023 17:18:35 +0800 Subject: [PATCH] use tools --- metagpt/roles/ml_engineer.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index b38c752a4..bd46ae79a 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -49,8 +49,8 @@ class MLEngineer(Role): self._watch([DownloadData, SubmitResult]) self.plan = Plan(goal=goal) - self.use_tools = False - self.use_code_steps = False + self.use_tools = True + self.use_code_steps = True self.execute_code = ExecutePyCode() self.auto_run = auto_run self.data_desc = {} @@ -101,8 +101,8 @@ class MLEngineer(Role): if self.use_tools: success, new_code = await self._update_data_columns() - if success: - task.code = task.code + "\n\n" + new_code + if success: + task.code = task.code + "\n\n" + new_code confirmed_and_more = (ReviewConst.CONTINUE_WORD[0] in review.lower() and review.lower() not in ReviewConst.CONTINUE_WORD[ @@ -245,9 +245,7 @@ class MLEngineer(Role): async def _reflect(self): context = self.get_memories() context = "\n".join([str(msg) for msg in context]) - # print("*" * 10) - # print(context) - # print("*" * 10) + reflection = await Reflect().run(context=context) self.working_memory.add(Message(content=reflection, role="assistant")) self.working_memory.add(Message(content=Reflect.REWRITE_PLAN_INSTRUCTION, role="user")) @@ -296,7 +294,7 @@ if __name__ == "__main__": # requirement = f"This is a customers financial dataset. Your goal is to predict which customers will make a specific transaction in the future. The target column is target. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report F1 Score on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv' ." save_dir = "" - save_dir = DATA_PATH / "save" / "2023-12-14_15-11-40" + # save_dir = DATA_PATH / "save" / "2023-12-14_16-58-03" def load_history(save_dir: str = save_dir): @@ -328,13 +326,14 @@ if __name__ == "__main__": Returns: Path: The path to the saved history directory. """ - save_path = Path(save_dir) if save_dir else DATA_PATH / "save" / datetime.now().strftime( + # save_path = Path(save_dir) if save_dir else DATA_PATH / "save" / datetime.now().strftime( + # '%Y-%m-%d_%H-%M-%S') + save_path = DATA_PATH / "save" / datetime.now().strftime( '%Y-%m-%d_%H-%M-%S') - # overwrite + # overwrite exist trajectory save_path.mkdir(parents=True, exist_ok=True) plan = role.plan.dict() - logger.info(f"Plan is {plan}") with open(save_path / "plan.json", "w", encoding="utf-8") as plan_file: json.dump(plan, plan_file, indent=4, ensure_ascii=False) @@ -361,8 +360,7 @@ if __name__ == "__main__": role = MLEngineer(goal=requirement, auto_run=auto_run) role.plan = Plan(**plan) role.execute_code = ExecutePyCode(nb) - import pdb; - pdb.set_trace() + else: logger.info("Run from scratch") role = MLEngineer(goal=requirement, auto_run=auto_run)