clean up input argument

This commit is contained in:
Yizhou Chi 2024-10-17 10:11:31 +08:00
parent 38daf24c33
commit a46f575361
4 changed files with 17 additions and 21 deletions

View file

@ -29,16 +29,14 @@ def initialize_di_root_node(state, reflection: bool = True):
return role, Node(parent=None, state=state, action=None, value=0)
def create_initial_state(
task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str, args
):
def create_initial_state(task, start_task_id, data_config, args):
external_eval = args.external_eval
if args.custom_dataset_dir:
dataset_config = None
datasets_dir = args.custom_dataset_dir
requirement = get_mle_bench_requirements(
args.custom_dataset_dir, data_config, special_instruction=special_instruction
args.custom_dataset_dir, data_config, special_instruction=args.special_instruction
)
exp_pool_path = None
# external_eval = False # make sure external eval is false if custom dataset is used
@ -46,20 +44,22 @@ def create_initial_state(
else:
dataset_config = data_config["datasets"][task]
datasets_dir = get_split_dataset_path(task, data_config)
requirement = generate_task_requirement(task, data_config, is_di=True, special_instruction=special_instruction)
requirement = generate_task_requirement(
task, data_config, is_di=True, special_instruction=args.special_instruction
)
exp_pool_path = get_exp_pool_path(task, data_config, pool_name="ds_analysis_pool")
initial_state = {
"task": task,
"work_dir": data_config["work_dir"],
"node_dir": os.path.join(data_config["work_dir"], data_config["role_dir"], f"{task}{name}"),
"node_dir": os.path.join(data_config["work_dir"], data_config["role_dir"], f"{task}{args.name}"),
"dataset_config": dataset_config,
"datasets_dir": datasets_dir, # won't be used if external eval is used
"exp_pool_path": exp_pool_path,
"requirement": requirement,
"has_run": False,
"start_task_id": start_task_id,
"low_is_better": low_is_better,
"low_is_better": args.low_is_better,
"role_timeout": args.role_timeout,
"external_eval": external_eval,
"custom_dataset_dir": args.custom_dataset_dir,

View file

@ -21,9 +21,7 @@ class CustomExperimenter(Experimenter):
self.task,
start_task_id=1,
data_config=self.data_config,
low_is_better=self.low_is_better,
name=self.name,
special_instruction=self.args.special_instruction,
args=self.args,
)
def run_experiment(self):

View file

@ -24,9 +24,6 @@ class Experimenter:
self.args.task,
start_task_id=self.start_task_id,
data_config=self.data_config,
low_is_better=self.args.low_is_better,
name=self.args.name,
special_instruction=self.args.special_instruction,
args=self.args,
)

View file

@ -24,9 +24,16 @@ def get_args(cmd=True):
get_mcts_args(parser)
get_aug_exp_args(parser)
if cmd:
return parser.parse_args()
args = parser.parse_args()
else:
return parser.parse_args("")
args = parser.parse_args("")
if args.custom_dataset_dir:
args.external_eval = False
args.eval_func = "mlebench"
args.from_scratch = True
args.task = get_mle_task_id(args.custom_dataset_dir)
return args
def get_mcts_args(parser):
@ -65,12 +72,6 @@ def get_di_args(parser):
async def main(args):
if args.custom_dataset_dir:
args.external_eval = False
args.eval_func = "mlebench"
args.from_scratch = True
args.task = get_mle_task_id(args.custom_dataset_dir)
if args.exp_mode == "mcts":
experimenter = MCTSExperimenter(args)
elif args.exp_mode == "greedy":