diff --git a/expo/MCTS.py b/expo/MCTS.py index cfb21a61c..1eb8a131c 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -274,7 +274,7 @@ class MCTS: # data_path root_node: Node = None children: dict = {} - max_depth: int = 5 + max_depth: int = None c_explore: float = 1.4 c_unvisited: float = 0.8 node_order: list = [] diff --git a/expo/experimenter/mcts.py b/expo/experimenter/mcts.py index 37fc7a071..a42566366 100644 --- a/expo/experimenter/mcts.py +++ b/expo/experimenter/mcts.py @@ -29,7 +29,7 @@ class MCTSExperimenter(Experimenter): async def run_experiment(self): use_fixed_insights = self.args.use_fixed_insights - depth = 5 + depth = self.args.max_depth if self.tree_mode == "greedy": mcts = Greedy(root_node=None, max_depth=depth, use_fixed_insights=use_fixed_insights) elif self.tree_mode == "random": diff --git a/expo/run_experiment.py b/expo/run_experiment.py index 71529b955..4e6b41fd7 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -44,6 +44,7 @@ def get_mcts_args(parser): parser.set_defaults(external_eval=True) parser.add_argument("--eval_func", type=str, default="sela", choices=["sela", "mlebench"]) parser.add_argument("--custom_dataset_dir", type=str, default=None) + parser.add_argument("--max_depth", type=int, default=4) def get_aug_exp_args(parser):