automatically update args.low_is_better for mle-bench

This commit is contained in:
Yizhou Chi 2024-10-17 17:47:41 +08:00
parent 06710fbc18
commit 852fbc58ee
2 changed files with 12 additions and 1 deletions

View file

@ -35,6 +35,15 @@ def get_mle_task_id(dataset_dir):
return dataset_dir.split("/")[-3]
def get_mle_is_lower_better(task):
from mlebench.data import get_leaderboard
from mlebench.registry import registry
competition = registry.get_competition(task)
competition_leaderboard = get_leaderboard(competition)
return competition.grader.is_lower_better(competition_leaderboard)
def get_mle_bench_requirements(dataset_dir, data_config, special_instruction, obfuscated=False):
work_dir = data_config["work_dir"]
task = get_mle_task_id(dataset_dir)

View file

@ -1,7 +1,7 @@
import argparse
import asyncio
from expo.data.custom_task import get_mle_task_id
from expo.data.custom_task import get_mle_is_lower_better, get_mle_task_id
from expo.experimenter.aug import AugExperimenter
from expo.experimenter.autogluon import GluonExperimenter
from expo.experimenter.autosklearn import AutoSklearnExperimenter
@ -33,6 +33,8 @@ def get_args(cmd=True):
args.eval_func = "mlebench"
args.from_scratch = True
args.task = get_mle_task_id(args.custom_dataset_dir)
args.low_is_better = get_mle_is_lower_better(args.task)
print("low_is_better:", args.low_is_better)
return args