update use_udfs.

This commit is contained in:
刘棒棒 2023-12-21 13:34:31 +08:00
parent 82dce58e4e
commit e8f5ce0f0a

View file

@ -148,7 +148,16 @@ class MLEngineer(Role):
)
logger.info(f"new code \n{code}")
cause_by = DebugCode
elif not self.use_tools or self.plan.current_task.task_type in ("other", "udf"):
elif not self.use_tools or self.plan.current_task.task_type == 'other':
logger.info("Write code with pure generation")
# TODO: 添加基于current_task.instruction-code_path的k-v缓存
code = await WriteCodeByGenerate().run(
context=context, plan=self.plan, temperature=0.0
)
debug_context = [self.get_useful_memories(task_exclude_field={'result', 'code_steps'})[0]]
cause_by = WriteCodeByGenerate
else:
logger.info("Write code with tools")
if self.use_udfs:
# use user-defined function tools.
from metagpt.tools.functions.libs.udf import UDFS_YAML
@ -165,24 +174,14 @@ class MLEngineer(Role):
debug_context = tool_context
cause_by = WriteCodeWithTools
else:
logger.info("Write code with pure generation")
# TODO: 添加基于current_task.instruction-code_path的k-v缓存
code = await WriteCodeByGenerate().run(
context=context, plan=self.plan, temperature=0.0
schema_path = PROJECT_ROOT / "metagpt/tools/functions/schemas"
tool_context, code = await WriteCodeWithTools(schema_path=schema_path).run(
context=context,
plan=self.plan,
column_info=self.data_desc.get("column_info", ""),
)
debug_context = [self.get_useful_memories(task_exclude_field={'result', 'code_steps'})[0]]
cause_by = WriteCodeByGenerate
else:
logger.info("Write code with tools")
schema_path = PROJECT_ROOT / "metagpt/tools/functions/schemas"
tool_context, code = await WriteCodeWithTools(schema_path=schema_path).run(
context=context,
plan=self.plan,
column_info=self.data_desc.get("column_info", ""),
)
debug_context = tool_context
cause_by = WriteCodeWithTools
debug_context = tool_context
cause_by = WriteCodeWithTools
self.working_memory.add(
Message(content=code, role="assistant", cause_by=cause_by)
)
@ -346,10 +345,10 @@ if __name__ == "__main__":
# data_path = f"{DATA_PATH}/santander-customer-transaction-prediction"
# requirement = f"This is a customers financial dataset. Your goal is to predict which customers will make a specific transaction in the future. The target column is target. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report F1 Score on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv' ."
data_path = f"{DATA_PATH}/house-prices-advanced-regression-techniques"
requirement = f"This is a house price dataset, your goal is to predict the sale price of a property based on its features. The target column is SalePrice. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report RMSE between the logarithm of the predicted value and the logarithm of the observed sales price on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
save_dir = ""
# save_dir = DATA_PATH / "output" / "2023-12-14_20-40-34"
# data_path = f"{DATA_PATH}/house-prices-advanced-regression-techniques"
# requirement = f"This is a house price dataset, your goal is to predict the sale price of a property based on its features. The target column is SalePrice. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report RMSE between the logarithm of the predicted value and the logarithm of the observed sales price on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
# save_dir = ""
# # save_dir = DATA_PATH / "output" / "2023-12-14_20-40-34"
async def main(requirement: str = requirement, auto_run: bool = True, use_tools: bool = False, use_code_steps: bool = False, save_dir: str = ""):
"""