diff --git a/expo/dataset.py b/expo/dataset.py index fee1199a9..f7e0301b5 100644 --- a/expo/dataset.py +++ b/expo/dataset.py @@ -16,9 +16,8 @@ Perform data analysis, data preprocessing, feature engineering, and modeling to Report {metric} on the eval data. Do not plot or make any visualizations. """ -TASK_PROMPT = """\ -# User requirement -{user_requirement} + +DI_INSTRUCTION = """\ **Attention** 1. Please do not leak the target label in any form during training. 2. Dev and Test sets do not have the target column. @@ -39,14 +38,19 @@ Print the training set performance in the last step. Write in this format: print("Train score:", train_score) ``` +# Output dir +{output_dir} +""" + +TASK_PROMPT = """\ +# User requirement +{user_requirement} +{additional_instruction} # Data dir training (with labels): {train_path} dev (without labels): {dev_path} testing (without labels): {test_path} - -# Output dir -{output_dir} - +dataset description: {data_info_path} (You can use this file to get additional information about the dataset) """ @@ -132,7 +136,12 @@ def create_dataset_dict(dataset): return dataset_dict -def generate_task_requirement(task_name, data_config): +def generate_di_instruction(output_dir): + additional_instruction = DI_INSTRUCTION.format(output_dir=output_dir) + return additional_instruction + + +def generate_task_requirement(task_name, data_config, is_di=True): user_requirement = get_user_requirement(task_name, data_config) split_dataset_path = get_split_dataset_path(task_name, data_config) train_path = split_dataset_path["train"] @@ -140,12 +149,19 @@ def generate_task_requirement(task_name, data_config): test_path = split_dataset_path["test_wo_target"] work_dir = data_config["work_dir"] output_dir = f"{work_dir}/{task_name}" + datasets_dir = data_config["datasets_dir"] + data_info_path = f"{datasets_dir}/{task_name}/dataset_info.json" + if is_di: + additional_instruction = generate_di_instruction(output_dir) + else: + additional_instruction = "" user_requirement = TASK_PROMPT.format( user_requirement=user_requirement, train_path=train_path, dev_path=dev_path, test_path=test_path, - output_dir=output_dir, + additional_instruction=additional_instruction, + data_info_path=data_info_path, ) print(user_requirement) return user_requirement