diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index e7fe38ff4..b039c61e7 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -9,7 +9,7 @@ from metagpt.schema import Message, Plan from metagpt.memory import Memory from metagpt.logs import logger from metagpt.actions.write_plan import WritePlan, update_plan_from_rsp, precheck_update_plan_from_rsp -from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools +from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools, MakeTools from metagpt.actions.ml_da_action import AskReview, SummarizeAnalysis, Reflect, ReviewConst from metagpt.actions.execute_code import ExecutePyCode from metagpt.roles.kaggle_manager import DownloadData, SubmitResult @@ -126,6 +126,10 @@ class MLEngineer(Role): context=context, plan=self.plan, code_steps=code_steps, temperature=0.0 ) cause_by = WriteCodeByGenerate + # make and save tools. + make_tools = MakeTools() + tool_code = await make_tools.run(code) + make_tools.save(tool_code) else: code = await WriteCodeWithTools().run( context=context, plan=self.plan, code_steps=code_steps, data_desc=""