From 9c54113f777990174bacd0ea789435f53982e71a Mon Sep 17 00:00:00 2001 From: Yizhou Chi Date: Fri, 11 Oct 2024 10:41:33 +0800 Subject: [PATCH] make sure image task starting from datapreprocessing --- expo/experimenter/mcts.py | 5 ++++- expo/run_experiment.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/expo/experimenter/mcts.py b/expo/experimenter/mcts.py index 22b480caf..c063268c8 100644 --- a/expo/experimenter/mcts.py +++ b/expo/experimenter/mcts.py @@ -8,9 +8,12 @@ from expo.MCTS import MCTS class MCTSExperimenter(Experimenter): result_path: str = "results/mcts" - start_task_id = 2 def __init__(self, args, tree_mode=None, **kwargs): + if args.special_instruction == "image": + self.start_task_id = 1 # start from datapreprocessing if it is image task + else: + self.start_task_id = args.start_task_id super().__init__(args, **kwargs) self.tree_mode = tree_mode diff --git a/expo/run_experiment.py b/expo/run_experiment.py index 49d058f13..15be27d60 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -31,6 +31,7 @@ def get_mcts_args(parser): parser.set_defaults(load_tree=False) parser.add_argument("--rollouts", type=int, default=5) parser.add_argument("--use_fixed_insights", dest="use_fixed_insights", action="store_true") + parser.add_argument("--start_task_id", type=int, default=2) def get_aug_exp_args(parser):