diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index c2df4bb79..cafd9b968 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -168,22 +168,16 @@ class MLEngineer(Role): \n{json.dumps(list(UDFS_YAML.keys()), indent=2, ensure_ascii=False)}") # set task_type to `udf` self.plan.current_task.task_type = 'udf' - tool_context, code = await WriteCodeWithTools(schema_path=UDFS_YAML).run( - context=context, - plan=self.plan, - column_info=self.data_desc.get("column_info", ""), - ) - debug_context = tool_context - cause_by = WriteCodeWithTools + schema_path = UDFS_YAML else: schema_path = PROJECT_ROOT / "metagpt/tools/functions/schemas" - tool_context, code = await WriteCodeWithTools(schema_path=schema_path).run( - context=context, - plan=self.plan, - column_info=self.data_desc.get("column_info", ""), - ) - debug_context = tool_context - cause_by = WriteCodeWithTools + tool_context, code = await WriteCodeWithTools(schema_path=schema_path).run( + context=context, + plan=self.plan, + column_info=self.data_desc.get("column_info", ""), + ) + debug_context = tool_context + cause_by = WriteCodeWithTools self.working_memory.add( Message(content=code, role="assistant", cause_by=cause_by) ) @@ -301,6 +295,7 @@ class MLEngineer(Role): # tool_code = await make_tools.run(code_prompt) tool_code = await make_tools.run(code) _, success = await self.execute_code.run(tool_code) + make_tool_retries += 1 if make_tool_current_retry > make_tool_retries: logger.error(f"We have tried the maximum number of attempts {make_tool_retries}\ and still have not created tools for task_id {self.plan.current_task_id} successfully,\