From 4cb2028c7240f8be607a9b9f57cdfb47bd197117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Tue, 19 Dec 2023 10:24:57 +0800 Subject: [PATCH] update for make tools test. --- metagpt/roles/ml_engineer.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index 1361c566f..75c403226 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -48,7 +48,8 @@ class MLEngineer(Role): self.plan = Plan(goal=goal) self.use_tools = False - self.make_tools = True + self.make_udfs = False + self.use_udfs = False self.use_code_steps = False self.execute_code = ExecutePyCode() self.auto_run = auto_run @@ -168,14 +169,19 @@ class MLEngineer(Role): logger.info(f"new code \n{code}") cause_by = DebugCode elif not self.use_tools or self.plan.current_task.task_type == "other": - logger.info("Write code with pure generation") - code = await WriteCodeByGenerate().run( - context=context, plan=self.plan, temperature=0.0 - ) - debug_context = [self.get_useful_memories(task_exclude_field={'result', 'code_steps'})[0]] - cause_by = WriteCodeByGenerate - if self.make_tools: - # make and save tools. + if self.use_udfs: + # use user-defined function tools. + pass + else: + logger.info("Write code with pure generation") + code = await WriteCodeByGenerate().run( + context=context, plan=self.plan, temperature=0.0 + ) + debug_context = [self.get_useful_memories(task_exclude_field={'result', 'code_steps'})[0]] + cause_by = WriteCodeByGenerate + + if self.make_udfs: + # make and save user-defined function tools. make_tools = MakeTools() tool_code = await make_tools.run(code) make_tools.save(tool_code) @@ -291,6 +297,7 @@ if __name__ == "__main__": async def main(requirement: str = requirement, auto_run: bool = True): role = MLEngineer(goal=requirement, auto_run=auto_run) + role.make_udfs = True await role.run(requirement) fire.Fire(main)