diff --git a/expo/data/custom_task.py b/expo/data/custom_task.py index e904e9496..fe366d7ea 100644 --- a/expo/data/custom_task.py +++ b/expo/data/custom_task.py @@ -34,14 +34,16 @@ def get_mle_task_id(dataset_dir): return dataset_dir.split("/")[-3] -def get_mle_bench_requirements(dataset_dir, data_config, obfuscated=False, special_instruction=""): +def get_mle_bench_requirements(dataset_dir, data_config, special_instruction, obfuscated=False): work_dir = data_config["work_dir"] task = get_mle_task_id(dataset_dir) output_dir = f"{work_dir}/{task}" final_output_dir = f"{work_dir}/submission" os.makedirs(output_dir, exist_ok=True) - special_instruction = SPECIAL_INSTRUCTIONS[special_instruction] - + if special_instruction: + special_instruction = SPECIAL_INSTRUCTIONS[special_instruction] + else: + special_instruction = "" if obfuscated: instructions = INSTRUCTIONS_OBFUSCATED.format(dataset_dir=dataset_dir, output_dir=final_output_dir) task_file = "description_obfuscated.md"