From a26c849b5087184cb1902f35ae74d3f5e0e280ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Thu, 14 Mar 2024 10:26:12 +0800 Subject: [PATCH] restore WalmartSalesForecast example. --- examples/di/machine_learning.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/di/machine_learning.py b/examples/di/machine_learning.py index a58735831..c674e66e8 100644 --- a/examples/di/machine_learning.py +++ b/examples/di/machine_learning.py @@ -2,11 +2,21 @@ import fire from metagpt.roles.di.data_interpreter import DataInterpreter +WINE_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy." -async def main(auto_run: bool = True): - requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy." - di = DataInterpreter(auto_run=auto_run) - await di.run(requirement) +DATA_DIR = "path/to/your/data" +# sales_forecast data from https://www.kaggle.com/datasets/aslanahmedov/walmart-sales-forecast/data +SALES_FORECAST_REQ = f"""Train a model to predict sales for each department in every store (split the last 40 weeks records as validation dataset, the others is train dataset), include plot total sales trends, print metric and plot scatter plots of +groud truth and predictions on validation data. Dataset is {DATA_DIR}/train.csv, the metric is weighted mean absolute error (WMAE) for test data. Notice: *print* key variables to get more information for next task step. +""" + +REQUIREMENTS = {"wine": WINE_REQ, "sales_forecast": SALES_FORECAST_REQ} + + +async def main(use_case: str = "wine"): + mi = DataInterpreter() + requirement = REQUIREMENTS[use_case] + await mi.run(requirement) if __name__ == "__main__":