From b776c7309bc64c0ec62c8efa4acbe93fdbd1a75f Mon Sep 17 00:00:00 2001 From: Yizhou Chi Date: Tue, 10 Sep 2024 15:30:23 +0800 Subject: [PATCH] add random tree search --- expo/Greedy.py | 10 ++++++++++ expo/data/dataset.py | 2 +- expo/experimenter/mcts.py | 10 ++++++---- expo/run_experiment.py | 8 ++++++-- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/expo/Greedy.py b/expo/Greedy.py index f6f60db01..8c8d865cd 100644 --- a/expo/Greedy.py +++ b/expo/Greedy.py @@ -1,3 +1,5 @@ +import random + from expo.MCTS import MCTS @@ -7,3 +9,11 @@ class Greedy(MCTS): return self.root_node all_children = [child for children in self.children.values() for child in children] return max(all_children, key=lambda x: x.normalized_reward.get("dev_score", 0)) + + +class Random(MCTS): + def best_child(self): + if len(self.children) == 0: + return self.root_node + all_children = [child for children in self.children.values() for child in children] + return random.choice(all_children) diff --git a/expo/data/dataset.py b/expo/data/dataset.py index 88528eb5c..43ac8ee0d 100644 --- a/expo/data/dataset.py +++ b/expo/data/dataset.py @@ -23,7 +23,7 @@ DI_INSTRUCTION = """\ 2. Test set does not have the target column. 3. You should perform transformations on train, dev, and test sets at the same time (it's a good idea to define functions for this and avoid code repetition). 4. If labels are transformed during training, they should be transformed back to the original format before saving the predictions. -5. You could split the training set further to make cross-validation and hyperparameter tuning. +5. You could utilize dev set to improve the model. ## Saving Dev and Test Predictions 1. Save the prediction results of BOTH the dev set and test set in `dev_predictions.csv` and `test_predictions.csv` respectively in the output directory. diff --git a/expo/experimenter/mcts.py b/expo/experimenter/mcts.py index 9bf7306c4..fbe2f35f1 100644 --- a/expo/experimenter/mcts.py +++ b/expo/experimenter/mcts.py @@ -1,19 +1,21 @@ from expo.evaluation.visualize_mcts import get_tree_text from expo.experimenter.experimenter import Experimenter -from expo.Greedy import Greedy +from expo.Greedy import Greedy, Random from expo.MCTS import MCTS class MCTSExperimenter(Experimenter): result_path: str = "results/mcts" - def __init__(self, args, greedy=False, **kwargs): + def __init__(self, args, tree_mode=None, **kwargs): super().__init__(args, **kwargs) - self.greedy = greedy + self.tree_mode = tree_mode async def run_experiment(self): - if self.greedy: + if self.tree_mode == "greedy": mcts = Greedy(root_node=None, max_depth=5) + elif self.tree_mode == "random": + mcts = Random(root_node=None, max_depth=5) else: mcts = MCTS(root_node=None, max_depth=5) best_nodes = await mcts.search( diff --git a/expo/run_experiment.py b/expo/run_experiment.py index 83237741a..b68607d79 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -10,7 +10,9 @@ from expo.experimenter.mcts import MCTSExperimenter def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--name", type=str, default="") - parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy"]) + parser.add_argument( + "--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy", "random"] + ) get_di_args(parser) get_mcts_args(parser) get_aug_exp_args(parser) @@ -42,7 +44,9 @@ async def main(args): if args.exp_mode == "mcts": experimenter = MCTSExperimenter(args) elif args.exp_mode == "greedy": - experimenter = MCTSExperimenter(args, greedy=True) + experimenter = MCTSExperimenter(args, tree_mode="greedy") + elif args.exp_mode == "random": + experimenter = MCTSExperimenter(args, tree_mode="random") elif args.exp_mode == "aug": experimenter = AugExperimenter(args) elif args.exp_mode == "base":