diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index 75c403226..96e21c8c8 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -180,10 +180,12 @@ class MLEngineer(Role): debug_context = [self.get_useful_memories(task_exclude_field={'result', 'code_steps'})[0]] cause_by = WriteCodeByGenerate - if self.make_udfs: + if self.make_udfs and len(code.split('\n')) > 2: # make and save user-defined function tools. make_tools = MakeTools() - tool_code = await make_tools.run(code) + code_prompt = f"The following code is about {self.plan.current_task.instruction},\ + convert it to be a General Function, {code}" + tool_code = await make_tools.run(code_prompt) make_tools.save(tool_code) else: logger.info("Write code with tools")