From a8ee20843b7492743c3fcb434c3423b130447ead Mon Sep 17 00:00:00 2001 From: Rayhao Date: Tue, 10 Sep 2024 19:20:24 -0700 Subject: [PATCH] autogluon small fix --- expo/experimenter/autogluon.py | 5 ++--- expo/run_experiment.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/expo/experimenter/autogluon.py b/expo/experimenter/autogluon.py index b33411773..478ecfc01 100644 --- a/expo/experimenter/autogluon.py +++ b/expo/experimenter/autogluon.py @@ -34,7 +34,7 @@ class GluonExperimenter(CustomExperimenter): super().__init__(args, **kwargs) self.framework = AGRunner(self.state) - def run_experiment(self): + async def run_experiment(self): result = self.framework.run() user_requirement = self.state["requirement"] dev_preds = result["dev_preds"] @@ -44,5 +44,4 @@ class GluonExperimenter(CustomExperimenter): "test_score": self.evaluate_predictions(test_preds, "test"), } results = [0, {"score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}] - self.save_result(results) - return results \ No newline at end of file + self.save_result(results) \ No newline at end of file diff --git a/expo/run_experiment.py b/expo/run_experiment.py index 74f9c6e57..cfdd295b2 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -11,7 +11,7 @@ from expo.experimenter.autogluon import GluonExperimenter def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--name", type=str, default="") - parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy", "autoglu"]) + parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy", "autogluon"]) get_di_args(parser) get_mcts_args(parser) get_aug_exp_args(parser) @@ -48,7 +48,7 @@ async def main(args): experimenter = AugExperimenter(args) elif args.exp_mode == "base": experimenter = Experimenter(args) - elif args.exp_mode == "autoglu": + elif args.exp_mode == "autogluon": experimenter = GluonExperimenter(args) elif args.exp_mode == "custom": experimenter = CustomExperimenter(args)