diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index 7e5cc8caf..092229ec9 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -148,7 +148,16 @@ class MLEngineer(Role): ) logger.info(f"new code \n{code}") cause_by = DebugCode - elif not self.use_tools or self.plan.current_task.task_type in ("other", "udf"): + elif not self.use_tools or self.plan.current_task.task_type == 'other': + logger.info("Write code with pure generation") + # TODO: 添加基于current_task.instruction-code_path的k-v缓存 + code = await WriteCodeByGenerate().run( + context=context, plan=self.plan, temperature=0.0 + ) + debug_context = [self.get_useful_memories(task_exclude_field={'result', 'code_steps'})[0]] + cause_by = WriteCodeByGenerate + else: + logger.info("Write code with tools") if self.use_udfs: # use user-defined function tools. from metagpt.tools.functions.libs.udf import UDFS_YAML @@ -165,24 +174,14 @@ class MLEngineer(Role): debug_context = tool_context cause_by = WriteCodeWithTools else: - logger.info("Write code with pure generation") - # TODO: 添加基于current_task.instruction-code_path的k-v缓存 - code = await WriteCodeByGenerate().run( - context=context, plan=self.plan, temperature=0.0 + schema_path = PROJECT_ROOT / "metagpt/tools/functions/schemas" + tool_context, code = await WriteCodeWithTools(schema_path=schema_path).run( + context=context, + plan=self.plan, + column_info=self.data_desc.get("column_info", ""), ) - debug_context = [self.get_useful_memories(task_exclude_field={'result', 'code_steps'})[0]] - cause_by = WriteCodeByGenerate - else: - logger.info("Write code with tools") - schema_path = PROJECT_ROOT / "metagpt/tools/functions/schemas" - tool_context, code = await WriteCodeWithTools(schema_path=schema_path).run( - context=context, - plan=self.plan, - column_info=self.data_desc.get("column_info", ""), - ) - debug_context = tool_context - cause_by = WriteCodeWithTools - + debug_context = tool_context + cause_by = WriteCodeWithTools self.working_memory.add( Message(content=code, role="assistant", cause_by=cause_by) ) @@ -346,10 +345,10 @@ if __name__ == "__main__": # data_path = f"{DATA_PATH}/santander-customer-transaction-prediction" # requirement = f"This is a customers financial dataset. Your goal is to predict which customers will make a specific transaction in the future. The target column is target. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report F1 Score on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv' ." - data_path = f"{DATA_PATH}/house-prices-advanced-regression-techniques" - requirement = f"This is a house price dataset, your goal is to predict the sale price of a property based on its features. The target column is SalePrice. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report RMSE between the logarithm of the predicted value and the logarithm of the observed sales price on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'." - save_dir = "" - # save_dir = DATA_PATH / "output" / "2023-12-14_20-40-34" + # data_path = f"{DATA_PATH}/house-prices-advanced-regression-techniques" + # requirement = f"This is a house price dataset, your goal is to predict the sale price of a property based on its features. The target column is SalePrice. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report RMSE between the logarithm of the predicted value and the logarithm of the observed sales price on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'." + # save_dir = "" + # # save_dir = DATA_PATH / "output" / "2023-12-14_20-40-34" async def main(requirement: str = requirement, auto_run: bool = True, use_tools: bool = False, use_code_steps: bool = False, save_dir: str = ""): """