allow max depth passing in

This commit is contained in:
Yizhou Chi 2024-10-16 10:53:37 +08:00
parent 7794b99005
commit 989a3b4299
3 changed files with 3 additions and 2 deletions

View file

@ -274,7 +274,7 @@ class MCTS:
# data_path
root_node: Node = None
children: dict = {}
max_depth: int = 5
max_depth: int = None
c_explore: float = 1.4
c_unvisited: float = 0.8
node_order: list = []

View file

@ -29,7 +29,7 @@ class MCTSExperimenter(Experimenter):
async def run_experiment(self):
use_fixed_insights = self.args.use_fixed_insights
depth = 5
depth = self.args.max_depth
if self.tree_mode == "greedy":
mcts = Greedy(root_node=None, max_depth=depth, use_fixed_insights=use_fixed_insights)
elif self.tree_mode == "random":

View file

@ -44,6 +44,7 @@ def get_mcts_args(parser):
parser.set_defaults(external_eval=True)
parser.add_argument("--eval_func", type=str, default="sela", choices=["sela", "mlebench"])
parser.add_argument("--custom_dataset_dir", type=str, default=None)
parser.add_argument("--max_depth", type=int, default=4)
def get_aug_exp_args(parser):