diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py
index 3c1853fd5..052b99ad5 100644
--- a/metagpt/roles/ml_engineer.py
+++ b/metagpt/roles/ml_engineer.py
@@ -21,6 +21,7 @@ from metagpt.prompts.ml_engineer import (
     PRINT_DATA_COLUMNS
 )
 from metagpt.roles import Role
+from metagpt.roles.role import RoleContext
 from metagpt.roles.kaggle_manager import DownloadData, SubmitResult
 from metagpt.schema import Message, Plan
 from metagpt.utils.common import remove_comments, create_func_config
@@ -192,16 +193,6 @@ class MLEngineer(Role):
                 )
                 debug_context = [self.get_useful_memories(task_exclude_field={'result', 'code_steps'})[0]]
                 cause_by = WriteCodeByGenerate
-
-                if self.make_udfs and len(code.split('\n')) > 4:
-                    # make and save user-defined function tools.
-                    logger.warning(f"Making tools for task_id {self.plan.current_task_id}: \
-                    `{self.plan.current_task.instruction}` \n code {code}")
-                    make_tools = MakeTools()
-                    code_prompt = f"The following code is about {self.plan.current_task.instruction},\
-                    convert it to be a General Function, {code}"
-                    tool_code = await make_tools.run(code_prompt)
-                    make_tools.save(tool_code)
             else:
                 logger.info("Write code with tools")
                 schema_path = PROJECT_ROOT / "metagpt/tools/functions/schemas"
@@ -219,6 +210,9 @@ class MLEngineer(Role):
 
             result, success = await self.execute_code.run(code)
             print(result)
+            # make tools for successful code and long code.
+            if success and self.make_udfs and len(code.split('\n')) > 4:
+                await self.make_tools(code=code)
             self.working_memory.add(
                 Message(content=result, role="user", cause_by=ExecutePyCode)
             )
@@ -304,6 +298,46 @@ class MLEngineer(Role):
     def get_working_memories(self) -> List[Message]:
         return self.working_memory.get()
 
+    def reset(self):
+        """Restart role with the same goal."""
+        self.plan = Plan(goal=self.plan.goal)
+        self.execute_code = ExecutePyCode()
+
+    async def make_tools(self, code: str):
+        """Make user-defined functions(udfs, aka tools) for pure generation code.
+
+        The generated tool code is run through self.execute_code as a check; on
+        failure it is regenerated and re-checked up to 3 more times. Only tool
+        code whose check succeeds is saved.
+
+        Args:
+            code (str): pure generation code by class WriteCodeByGenerate.
+        """
+        logger.warning(f"Making tools for task_id {self.plan.current_task_id}: \
+        `{self.plan.current_task.instruction}` \n code: \n {code}")
+        make_tools = MakeTools()
+        code_prompt = f"The following code is about {self.plan.current_task.instruction},\
+        convert it to be a General Function, {code}"
+        tool_code = await make_tools.run(code_prompt)
+        # check tool_code by execute_code
+        logger.info(f"Checking task_id {self.plan.current_task_id} tool code by executor...")
+        _, success = await self.execute_code.run(tool_code)
+        make_tool_retries, make_tool_current_retry = 3, 1
+        while not success:
+            # Give up once the retry budget is spent; without the counter
+            # increment below this loop would never terminate on failure.
+            if make_tool_current_retry > make_tool_retries:
+                logger.error(f"We have tried the maximum number of attempts {make_tool_retries}\
+                and still have not created tools for task_id {self.plan.current_task_id} successfully,\
+                we will skip it.")
+                break
+            tool_code = await make_tools.run(code_prompt)
+            _, success = await self.execute_code.run(tool_code)
+            make_tool_current_retry += 1
+        # save successful tool code in udf
+        if success:
+            make_tools.save(tool_code)
+
 
 if __name__ == "__main__":
     requirement = "Run data analysis on sklearn Iris dataset, include a plot"
@@ -314,6 +348,12 @@ if __name__ == "__main__":
 
     async def main(requirement: str = requirement, auto_run: bool = True):
         role = MLEngineer(goal=requirement, auto_run=auto_run)
+        # make udfs
+        role.make_udfs = True
+        role.use_udfs = False
+        await role.run(requirement)
+        # use udfs
+        role.reset()
         role.make_udfs = False
         role.use_udfs = True
         await role.run(requirement)