diff --git a/expo/MCTS.py b/expo/MCTS.py index 8778554ed..2ce559ae0 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -43,6 +43,8 @@ def create_initial_state(task, start_task_id, data_config, args): task = get_mle_task_id(args.custom_dataset_dir) else: dataset_config = data_config["datasets"][task] + if dataset_config["metric"] == "rmse": + args.low_is_better = True datasets_dir = get_split_dataset_path(task, data_config) requirement = generate_task_requirement( task, data_config, is_di=True, special_instruction=args.special_instruction