diff --git a/expo/experimenter/autogluon.py b/expo/experimenter/autogluon.py index 478ecfc01..93dfdb4bc 100644 --- a/expo/experimenter/autogluon.py +++ b/expo/experimenter/autogluon.py @@ -1,17 +1,19 @@ from datetime import datetime -from autogluon.tabular import TabularDataset, TabularPredictor + from expo.experimenter.custom import CustomExperimenter class AGRunner: preset = "best_quality" - time_limit = 1000 # 1000s + time_limit = 1000 # 1000s def __init__(self, state=None): self.state = state self.datasets = self.state["datasets_dir"] def run(self): + from autogluon.tabular import TabularDataset, TabularPredictor + train_path = self.datasets["train"] dev_wo_target_path = self.datasets["dev_wo_target"] test_wo_target_path = self.datasets["test_wo_target"] @@ -21,7 +23,11 @@ class AGRunner: test_data = TabularDataset(test_wo_target_path) eval_metric = self.state["dataset_config"]["metric"].replace(" ", "_") # predictor = TabularPredictor(label=target_col, eval_metric=eval_metric, path="AutogluonModels/ag-{}-{}".format(self.state['task'], datetime.now().strftime("%y%m%d_%H%M"))).fit(train_data, presets=self.preset, time_limit=self.time_limit, fit_weighted_ensemble=False, num_gpus=1) - predictor = TabularPredictor(label=target_col, eval_metric=eval_metric, path="AutogluonModels/ag-{}-{}".format(self.state['task'], datetime.now().strftime("%y%m%d_%H%M"))).fit(train_data, num_gpus=1) + predictor = TabularPredictor( + label=target_col, + eval_metric=eval_metric, + path="AutogluonModels/ag-{}-{}".format(self.state["task"], datetime.now().strftime("%y%m%d_%H%M")), + ).fit(train_data, num_gpus=1) dev_preds = predictor.predict(dev_data) test_preds = predictor.predict(test_data) return {"test_preds": test_preds, "dev_preds": dev_preds} @@ -44,4 +50,4 @@ class GluonExperimenter(CustomExperimenter): "test_score": self.evaluate_predictions(test_preds, "test"), } results = [0, {"score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}] - self.save_result(results) \ No newline at end of file + self.save_result(results)