diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index fa9acadbc..3c1853fd5 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -175,6 +175,8 @@ class MLEngineer(Role): logger.warning("Writing code with user-defined function tools...") logger.info(f"Local user defined function as following:\ \n{json.dumps(list(UDFS_YAML.keys()), indent=2, ensure_ascii=False)}") + # set task_type to `udf` + self.plan.current_task.task_type = 'udf' tool_context, code = await WriteCodeWithTools(schema_path=UDFS_YAML).run( context=context, plan=self.plan, @@ -184,6 +186,7 @@ class MLEngineer(Role): cause_by = WriteCodeWithTools else: logger.info("Write code with pure generation") + # TODO: 添加基于current_task.instruction-code_path的k-v缓存 code = await WriteCodeByGenerate().run( context=context, plan=self.plan, temperature=0.0 )