diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index b908d9ef8..9fa12b41d 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -99,7 +99,7 @@ class MLEngineer(Role): self.plan.finish_current_task() self.working_memory.clear() - if self.use_tools: + if self.use_tools or self.use_udfs: success, new_code = await self._update_data_columns() if success: task.code = task.code + "\n\n" + new_code @@ -159,7 +159,8 @@ class MLEngineer(Role): # print(context) # print("*" * 10) # breakpoint() - if counter > 0 and self.use_tools: + if counter > 0 and (self.use_tools or self.use_udfs): + logger.warning('We got a bug code, now start to debug...') code = await DebugCode().run( plan=self.plan.current_task.instruction, code=code, @@ -168,11 +169,11 @@ class MLEngineer(Role): ) logger.info(f"new code \n{code}") cause_by = DebugCode - elif not self.use_tools or self.plan.current_task.task_type == "other": + elif not self.use_tools or self.plan.current_task.task_type in ("other", "udf"): if self.use_udfs: # use user-defined function tools. from metagpt.tools.functions.libs.udf import UDFS_YAML - logger.warning("Writing code with user-defined function tools...") + logger.warning("Writing code with user-defined function tools by WriteCodeWithTools.") logger.info(f"Local user defined function as following:\ \n{json.dumps(list(UDFS_YAML.keys()), indent=2, ensure_ascii=False)}") # set task_type to `udf` @@ -211,6 +212,7 @@ class MLEngineer(Role): print(result) # make tools for successful code and long code. if success and self.make_udfs and len(code.split('\n')) > 4: + logger.info('Execute code successfully. Now start to make tools ...') await self.make_tools(code=code) self.working_memory.add( Message(content=result, role="user", cause_by=ExecutePyCode)