From fc4017480205104f281f0367ef83acc433375a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Wed, 21 Feb 2024 21:33:11 +0800 Subject: [PATCH] chore. --- examples/mi/machine_learning.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/examples/mi/machine_learning.py b/examples/mi/machine_learning.py index 689335db3..5f9d5b0cd 100644 --- a/examples/mi/machine_learning.py +++ b/examples/mi/machine_learning.py @@ -2,29 +2,30 @@ import fire from metagpt.roles.mi.interpreter import Interpreter -DATA_DIR = "examples/mi/data" WINE_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy." -# sales_forecast data from https://www.kaggle.com/datasets/aslanahmedov/walmart-sales-forecast/data, -# new_train, new_test from train.csv. +# DATA_DIR = "your/path/to/data" +DATA_DIR = "examples/mi/data/WalmartSalesForecast2" +# sales_forecast data from https://www.kaggle.com/datasets/aslanahmedov/walmart-sales-forecast/data SALES_FORECAST_REQ = f""" # Goal -Use time series regression machine learning to make predictions for Dept sales of the stores as accurate as possible. +Train a model to predict sales for each department in every store (split the last 40 weeks records as validation dataset, +the others is train dataset), include plot sales trends, holiday effects, distribution of sales across stores/departments, +using box on the train dataset, print metric and plot scatter plots of groud truth and predictions on validation data. +save predictions on test data. # Datasets Available -- train_data: {DATA_DIR}/WalmartSalesForecast/new_train.csv -- test_data: {DATA_DIR}/WalmartSalesForecast/new_test.csv -- additional data: {DATA_DIR}/WalmartSalesForecast/features.csv; To merge on train, test data. -- stores data: {DATA_DIR}/WalmartSalesForecast/stores.csv; To merge on train, test data. +- train_data: {DATA_DIR}/train.csv +- test_data: {DATA_DIR}/test.csv, no label data. +- additional data: {DATA_DIR}/features.csv +- stores data: {DATA_DIR}/stores.csv # Metric The metric of the competition is weighted mean absolute error (WMAE) for test data. # Notice - *print* key variables to get more information for next task step. -- Perform data analysis by plotting sales trends, holiday effects, distribution of sales across stores/departments using box/violin on the train data. -- Make sure the DataFrame.dtypes must be int, float or bool, and drop date column. -- Plot scatter plots of groud truth and predictions on test data. +- Only When you fit the model, make the DataFrame.dtypes to be int, float or bool, and drop date column. """ requirements = {"wine": WINE_REQ, "sales_forecast": SALES_FORECAST_REQ} @@ -32,7 +33,12 @@ requirements = {"wine": WINE_REQ, "sales_forecast": SALES_FORECAST_REQ} async def main(auto_run: bool = True, use_case: str = "wine"): mi = Interpreter(auto_run=auto_run) - await mi.run(requirements[use_case]) + if use_case == "wine": + requirement = requirements[use_case] + else: + assert DATA_DIR != "your/path/to/data", f"Please set DATA_DIR for the use_case: {use_case}!" + requirement = requirements[use_case] + await mi.run(requirement) if __name__ == "__main__":