From a46f5753612c9e5e5e6a948f309873f994b16174 Mon Sep 17 00:00:00 2001 From: Yizhou Chi Date: Thu, 17 Oct 2024 10:11:31 +0800 Subject: [PATCH] clean up input argument --- expo/MCTS.py | 14 +++++++------- expo/experimenter/custom.py | 4 +--- expo/experimenter/experimenter.py | 3 --- expo/run_experiment.py | 17 +++++++++-------- 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/expo/MCTS.py b/expo/MCTS.py index 1eb8a131c..8778554ed 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -29,16 +29,14 @@ def initialize_di_root_node(state, reflection: bool = True): return role, Node(parent=None, state=state, action=None, value=0) -def create_initial_state( - task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str, args -): +def create_initial_state(task, start_task_id, data_config, args): external_eval = args.external_eval if args.custom_dataset_dir: dataset_config = None datasets_dir = args.custom_dataset_dir requirement = get_mle_bench_requirements( - args.custom_dataset_dir, data_config, special_instruction=special_instruction + args.custom_dataset_dir, data_config, special_instruction=args.special_instruction ) exp_pool_path = None # external_eval = False # make sure external eval is false if custom dataset is used @@ -46,20 +44,22 @@ def create_initial_state( else: dataset_config = data_config["datasets"][task] datasets_dir = get_split_dataset_path(task, data_config) - requirement = generate_task_requirement(task, data_config, is_di=True, special_instruction=special_instruction) + requirement = generate_task_requirement( + task, data_config, is_di=True, special_instruction=args.special_instruction + ) exp_pool_path = get_exp_pool_path(task, data_config, pool_name="ds_analysis_pool") initial_state = { "task": task, "work_dir": data_config["work_dir"], - "node_dir": os.path.join(data_config["work_dir"], data_config["role_dir"], f"{task}{name}"), + "node_dir": os.path.join(data_config["work_dir"], data_config["role_dir"], f"{task}{args.name}"), "dataset_config": dataset_config, "datasets_dir": datasets_dir, # won't be used if external eval is used "exp_pool_path": exp_pool_path, "requirement": requirement, "has_run": False, "start_task_id": start_task_id, - "low_is_better": low_is_better, + "low_is_better": args.low_is_better, "role_timeout": args.role_timeout, "external_eval": external_eval, "custom_dataset_dir": args.custom_dataset_dir, diff --git a/expo/experimenter/custom.py b/expo/experimenter/custom.py index 92b7dafa2..f245499ca 100644 --- a/expo/experimenter/custom.py +++ b/expo/experimenter/custom.py @@ -21,9 +21,7 @@ class CustomExperimenter(Experimenter): self.task, start_task_id=1, data_config=self.data_config, - low_is_better=self.low_is_better, - name=self.name, - special_instruction=self.args.special_instruction, + args=self.args, ) def run_experiment(self): diff --git a/expo/experimenter/experimenter.py b/expo/experimenter/experimenter.py index 417adabad..4a0b8413e 100644 --- a/expo/experimenter/experimenter.py +++ b/expo/experimenter/experimenter.py @@ -24,9 +24,6 @@ class Experimenter: self.args.task, start_task_id=self.start_task_id, data_config=self.data_config, - low_is_better=self.args.low_is_better, - name=self.args.name, - special_instruction=self.args.special_instruction, args=self.args, ) diff --git a/expo/run_experiment.py b/expo/run_experiment.py index c977b4dc9..be891814d 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -24,9 +24,16 @@ def get_args(cmd=True): get_mcts_args(parser) get_aug_exp_args(parser) if cmd: - return parser.parse_args() + args = parser.parse_args() else: - return parser.parse_args("") + args = parser.parse_args("") + + if args.custom_dataset_dir: + args.external_eval = False + args.eval_func = "mlebench" + args.from_scratch = True + args.task = get_mle_task_id(args.custom_dataset_dir) + return args def get_mcts_args(parser): @@ -65,12 +72,6 @@ def get_di_args(parser): async def main(args): - if args.custom_dataset_dir: - args.external_eval = False - args.eval_func = "mlebench" - args.from_scratch = True - args.task = get_mle_task_id(args.custom_dataset_dir) - if args.exp_mode == "mcts": experimenter = MCTSExperimenter(args) elif args.exp_mode == "greedy":