From 3dde4664f40f9ae9f92441369a7cab0c88f7e68d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Wed, 21 Feb 2024 11:03:32 +0800 Subject: [PATCH] add sales_forecast in machine_learning. --- examples/mi/machine_learning.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/examples/mi/machine_learning.py b/examples/mi/machine_learning.py index a8ab5051e..a76561a37 100644 --- a/examples/mi/machine_learning.py +++ b/examples/mi/machine_learning.py @@ -3,10 +3,36 @@ import fire from metagpt.roles.mi.interpreter import Interpreter -async def main(auto_run: bool = True): - requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy." +DATA_DIR = "examples/mi/data" +requirements = { + "wine": "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.", + + # sales_forecast data from https://www.kaggle.com/datasets/aslanahmedov/walmart-sales-forecast/data + "sales_forecast": f""" + # Goal + Use time series regression machine learning to make predictions for Dept sales of the stores as accurate as possible. + + # Datasets Available + - train_data: {DATA_DIR}/WalmartSalesForecast/new_train.csv + - test_data: {DATA_DIR}/WalmartSalesForecast/new_test.csv + - additional data: {DATA_DIR}/WalmartSalesForecast/features.csv; To merge on train, test data. + - stores data: {DATA_DIR}/WalmartSalesForecast/stores.csv; To merge on train, test data. + + # Metric + The metric of the competition is weighted mean absolute error (WMAE) for test data. + + # Notice + - *print* key variables to get more information for next task step. + - Perform data analysis by plotting sales trends, holiday effects, distribution of sales across stores/departments using box/violin on the train data. + - Make sure the DataFrame.dtypes must be int, float or bool, and drop date column. + - Plot scatter plots of groud truth and predictions on test data. + """ +} + + +async def main(auto_run: bool = True, use_case: str = 'wine'): mi = Interpreter(auto_run=auto_run) - await mi.run(requirement) + await mi.run(requirements[use_case]) if __name__ == "__main__":