diff --git a/expo/experimenter/mcts.py b/expo/experimenter/mcts.py index 22b480caf..c063268c8 100644 --- a/expo/experimenter/mcts.py +++ b/expo/experimenter/mcts.py @@ -8,9 +8,12 @@ from expo.MCTS import MCTS class MCTSExperimenter(Experimenter): result_path: str = "results/mcts" - start_task_id = 2 def __init__(self, args, tree_mode=None, **kwargs): + if args.special_instruction == "image": + self.start_task_id = 1 # start from datapreprocessing if it is image task + else: + self.start_task_id = args.start_task_id super().__init__(args, **kwargs) self.tree_mode = tree_mode diff --git a/expo/run_experiment.py b/expo/run_experiment.py index 49d058f13..15be27d60 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -31,6 +31,7 @@ def get_mcts_args(parser): parser.set_defaults(load_tree=False) parser.add_argument("--rollouts", type=int, default=5) parser.add_argument("--use_fixed_insights", dest="use_fixed_insights", action="store_true") + parser.add_argument("--start_task_id", type=int, default=2) def get_aug_exp_args(parser):