From c43e2bed6b916096f117f72db393903694a7c090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Thu, 21 Dec 2023 10:14:25 +0800 Subject: [PATCH] update condition for DebugCode. --- metagpt/roles/ml_engineer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index b908d9ef8..9fa12b41d 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -99,7 +99,7 @@ class MLEngineer(Role): self.plan.finish_current_task() self.working_memory.clear() - if self.use_tools: + if self.use_tools or self.use_udfs: success, new_code = await self._update_data_columns() if success: task.code = task.code + "\n\n" + new_code @@ -159,7 +159,8 @@ class MLEngineer(Role): # print(context) # print("*" * 10) # breakpoint() - if counter > 0 and self.use_tools: + if counter > 0 and (self.use_tools or self.use_udfs): + logger.warning('We got a bug code, now start to debug...') code = await DebugCode().run( plan=self.plan.current_task.instruction, code=code, @@ -168,11 +169,11 @@ class MLEngineer(Role): ) logger.info(f"new code \n{code}") cause_by = DebugCode - elif not self.use_tools or self.plan.current_task.task_type == "other": + elif not self.use_tools or self.plan.current_task.task_type in ("other", "udf"): if self.use_udfs: # use user-defined function tools. from metagpt.tools.functions.libs.udf import UDFS_YAML - logger.warning("Writing code with user-defined function tools...") + logger.warning("Writing code with user-defined function tools by WriteCodeWithTools.") logger.info(f"Local user defined function as following:\ \n{json.dumps(list(UDFS_YAML.keys()), indent=2, ensure_ascii=False)}") # set task_type to `udf` @@ -211,6 +212,7 @@ class MLEngineer(Role): print(result) # make tools for successful code and long code. if success and self.make_udfs and len(code.split('\n')) > 4: + logger.info('Execute code successfully. Now start to make tools ...') await self.make_tools(code=code) self.working_memory.add( Message(content=result, role="user", cause_by=ExecutePyCode)