make sure image task starting from datapreprocessing

This commit is contained in:
Yizhou Chi 2024-10-11 10:41:33 +08:00
parent cdfb413f9d
commit 9c54113f77
2 changed files with 5 additions and 1 deletions

View file

@ -8,9 +8,12 @@ from expo.MCTS import MCTS
class MCTSExperimenter(Experimenter):
result_path: str = "results/mcts"
start_task_id = 2
def __init__(self, args, tree_mode=None, **kwargs):
if args.special_instruction == "image":
self.start_task_id = 1 # start from datapreprocessing if it is image task
else:
self.start_task_id = args.start_task_id
super().__init__(args, **kwargs)
self.tree_mode = tree_mode

View file

@ -31,6 +31,7 @@ def get_mcts_args(parser):
parser.set_defaults(load_tree=False)
parser.add_argument("--rollouts", type=int, default=5)
parser.add_argument("--use_fixed_insights", dest="use_fixed_insights", action="store_true")
parser.add_argument("--start_task_id", type=int, default=2)
def get_aug_exp_args(parser):