fix import error delete seed

This commit is contained in:
duiyipan 2024-09-14 14:58:22 +08:00
parent 5d2de4d0ec
commit c4fe056bca

View file

@ -9,15 +9,8 @@ def custom_scorer(y_true, y_pred, metric_name):
return evaluate_score(y_pred, y_true, metric_name)
def create_autosklearn_scorer(metric_name):
return make_scorer(
name=metric_name, score_func=partial(custom_scorer, metric_name=metric_name)
)
class ASRunner:
time_limit = 300
seed = 42
def __init__(self, state=None):
self.state = state
@ -25,12 +18,19 @@ class ASRunner:
try:
import autosklearn.classification
import autosklearn.regression
from autosklearn.metrics import make_scorer
import autosklearn.metrics
self.autosklearn = autosklearn
except ImportError:
raise ImportError(
"autosklearn not found or system not supported, please check it first"
)
def create_autosklearn_scorer(self, metric_name):
return self.autosklearn.metrics.make_scorer(
name=metric_name, score_func=partial(custom_scorer, metric_name=metric_name)
)
def run(self):
train_path = self.datasets["train"]
dev_wo_target_path = self.datasets["dev_wo_target"]
@ -45,24 +45,22 @@ class ASRunner:
y_train = train_data[target_col]
if eval_metric == "rmse":
automl = autosklearn.regression.AutoSklearnRegressor(
automl = self.autosklearn.regression.AutoSklearnRegressor(
time_left_for_this_task=self.time_limit,
per_run_time_limit=60,
metric=create_autosklearn_scorer(eval_metric),
metric=self.create_autosklearn_scorer(eval_metric),
memory_limit=8192,
seed=self.seed,
tmp_folder="AutosklearnModels/as-{}-{}".format(
self.state["task"], datetime.now().strftime("%y%m%d_%H%M")
),
n_jobs=-1,
)
elif eval_metric in ["f1", "f1 weighted"]:
automl = autosklearn.classification.AutoSklearnClassifier(
automl = self.autosklearn.classification.AutoSklearnClassifier(
time_left_for_this_task=self.time_limit,
per_run_time_limit=60,
metric=create_autosklearn_scorer(eval_metric),
metric=self.create_autosklearn_scorer(eval_metric),
memory_limit=8192,
seed=self.seed,
tmp_folder="AutosklearnModels/as-{}-{}".format(
self.state["task"], datetime.now().strftime("%y%m%d_%H%M")
),