update prompt to include dataset info path

This commit is contained in:
Yizhou Chi 2024-09-05 13:01:00 +08:00
parent 45d176b48b
commit d27a48adb2

View file

@ -16,9 +16,8 @@ Perform data analysis, data preprocessing, feature engineering, and modeling to
Report {metric} on the eval data. Do not plot or make any visualizations.
"""
TASK_PROMPT = """\
# User requirement
{user_requirement}
DI_INSTRUCTION = """\
**Attention**
1. Please do not leak the target label in any form during training.
2. Dev and Test sets do not have the target column.
@ -39,14 +38,19 @@ Print the training set performance in the last step. Write in this format:
print("Train score:", train_score)
```
# Output dir
{output_dir}
"""
TASK_PROMPT = """\
# User requirement
{user_requirement}
{additional_instruction}
# Data dir
training (with labels): {train_path}
dev (without labels): {dev_path}
testing (without labels): {test_path}
# Output dir
{output_dir}
dataset description: {data_info_path} (You can use this file to get additional information about the dataset)
"""
@ -132,7 +136,12 @@ def create_dataset_dict(dataset):
return dataset_dict
def generate_task_requirement(task_name, data_config):
def generate_di_instruction(output_dir):
    """Return the DI_INSTRUCTION template with its output directory filled in.

    Args:
        output_dir: Value substituted for the ``output_dir`` field when
            formatting ``DI_INSTRUCTION``.

    Returns:
        The result of ``DI_INSTRUCTION.format(output_dir=output_dir)``.
    """
    return DI_INSTRUCTION.format(output_dir=output_dir)
def generate_task_requirement(task_name, data_config, is_di=True):
user_requirement = get_user_requirement(task_name, data_config)
split_dataset_path = get_split_dataset_path(task_name, data_config)
train_path = split_dataset_path["train"]
@ -140,12 +149,19 @@ def generate_task_requirement(task_name, data_config):
test_path = split_dataset_path["test_wo_target"]
work_dir = data_config["work_dir"]
output_dir = f"{work_dir}/{task_name}"
datasets_dir = data_config["datasets_dir"]
data_info_path = f"{datasets_dir}/{task_name}/dataset_info.json"
if is_di:
additional_instruction = generate_di_instruction(output_dir)
else:
additional_instruction = ""
user_requirement = TASK_PROMPT.format(
user_requirement=user_requirement,
train_path=train_path,
dev_path=dev_path,
test_path=test_path,
output_dir=output_dir,
additional_instruction=additional_instruction,
data_info_path=data_info_path,
)
print(user_requirement)
return user_requirement