autogluon small fix

This commit is contained in:
Rayhao 2024-09-10 19:20:24 -07:00
parent e3663f2322
commit a8ee20843b
2 changed files with 4 additions and 5 deletions

View file

@ -34,7 +34,7 @@ class GluonExperimenter(CustomExperimenter):
super().__init__(args, **kwargs)
self.framework = AGRunner(self.state)
def run_experiment(self):
async def run_experiment(self):
result = self.framework.run()
user_requirement = self.state["requirement"]
dev_preds = result["dev_preds"]
@ -44,5 +44,4 @@ class GluonExperimenter(CustomExperimenter):
"test_score": self.evaluate_predictions(test_preds, "test"),
}
results = [0, {"score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}]
self.save_result(results)
return results
self.save_result(results)

View file

@ -11,7 +11,7 @@ from expo.experimenter.autogluon import GluonExperimenter
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="")
parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy", "autoglu"])
parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy", "autogluon"])
get_di_args(parser)
get_mcts_args(parser)
get_aug_exp_args(parser)
@ -48,7 +48,7 @@ async def main(args):
experimenter = AugExperimenter(args)
elif args.exp_mode == "base":
experimenter = Experimenter(args)
elif args.exp_mode == "autoglu":
elif args.exp_mode == "autogluon":
experimenter = GluonExperimenter(args)
elif args.exp_mode == "custom":
experimenter = CustomExperimenter(args)