This commit is contained in:
刘棒棒 2024-02-21 21:33:11 +08:00
parent 662fbd7e55
commit fc40174802

View file

@ -2,29 +2,30 @@ import fire
from metagpt.roles.mi.interpreter import Interpreter
# Requirement prompt for the sklearn Wine classification use case.
WINE_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."

# Directory holding the Walmart sales-forecast CSV files (train.csv, test.csv,
# features.csv, stores.csv) that SALES_FORECAST_REQ points the interpreter at.
# sales_forecast data from https://www.kaggle.com/datasets/aslanahmedov/walmart-sales-forecast/data
# To use your own copy of the data, point DATA_DIR at it, e.g.:
# DATA_DIR = "your/path/to/data"
DATA_DIR = "examples/mi/data/WalmartSalesForecast2"

# Requirement prompt for the sales-forecast use case. This is an f-string, so
# DATA_DIR is interpolated once, when the module is imported; DATA_DIR must be
# set above this point. Each dataset is listed exactly once, relative to
# DATA_DIR, and ground-truth plots are requested only for the validation split
# because the test data carries no labels.
SALES_FORECAST_REQ = f"""
# Goal
Use time series regression machine learning to make predictions for Dept sales of the stores as accurate as possible.
Train a model to predict sales for each department in every store (split the last 40 weeks records as validation dataset,
the others is train dataset), include plot sales trends, holiday effects, distribution of sales across stores/departments,
using box on the train dataset, print metric and plot scatter plots of ground truth and predictions on validation data.
save predictions on test data.
# Datasets Available
- train_data: {DATA_DIR}/train.csv
- test_data: {DATA_DIR}/test.csv, no label data.
- additional data: {DATA_DIR}/features.csv
- stores data: {DATA_DIR}/stores.csv
# Metric
The metric of the competition is weighted mean absolute error (WMAE) for test data.
# Notice
- *print* key variables to get more information for next task step.
- Perform data analysis by plotting sales trends, holiday effects, distribution of sales across stores/departments using box/violin on the train data.
- Only When you fit the model, make the DataFrame.dtypes to be int, float or bool, and drop date column.
"""

# Maps each supported use-case name to the requirement prompt that is run for it.
requirements = {"wine": WINE_REQ, "sales_forecast": SALES_FORECAST_REQ}
@ -32,7 +33,12 @@ requirements = {"wine": WINE_REQ, "sales_forecast": SALES_FORECAST_REQ}
async def main(auto_run: bool = True, use_case: str = "wine"):
    """Run the requirement prompt selected by ``use_case`` through an Interpreter.

    Args:
        auto_run: Forwarded unchanged to ``Interpreter(auto_run=...)``.
        use_case: Key into ``requirements`` selecting which prompt to run.

    Raises:
        KeyError: If ``use_case`` is not a key of ``requirements``.
        ValueError: If a use case other than "wine" is requested while
            ``DATA_DIR`` still holds the "your/path/to/data" placeholder.
    """
    # Only the non-wine use cases read files under DATA_DIR, so the placeholder
    # check is skipped for "wine". An explicit raise is used instead of
    # `assert` so the check is not stripped when Python runs with -O.
    if use_case != "wine" and DATA_DIR == "your/path/to/data":
        raise ValueError(f"Please set DATA_DIR for the use_case: {use_case}!")

    # Resolve the prompt before building the interpreter so that a bad
    # use_case fails fast.
    requirement = requirements[use_case]

    mi = Interpreter(auto_run=auto_run)
    # Run the interpreter exactly once with the resolved requirement.
    await mi.run(requirement)
if __name__ == "__main__":