diff --git a/expo/experimenter/autosklearn.py b/expo/experimenter/autosklearn.py index 22aa2a132..602b8385a 100644 --- a/expo/experimenter/autosklearn.py +++ b/expo/experimenter/autosklearn.py @@ -9,15 +9,8 @@ def custom_scorer(y_true, y_pred, metric_name): return evaluate_score(y_pred, y_true, metric_name) -def create_autosklearn_scorer(metric_name): - return make_scorer( - name=metric_name, score_func=partial(custom_scorer, metric_name=metric_name) - ) - - class ASRunner: time_limit = 300 - seed = 42 def __init__(self, state=None): self.state = state @@ -25,12 +18,19 @@ class ASRunner: try: import autosklearn.classification import autosklearn.regression - from autosklearn.metrics import make_scorer + import autosklearn.metrics + + self.autosklearn = autosklearn except ImportError: raise ImportError( "autosklearn not found or system not supported, please check it first" ) + def create_autosklearn_scorer(self, metric_name): + return self.autosklearn.metrics.make_scorer( + name=metric_name, score_func=partial(custom_scorer, metric_name=metric_name) + ) + def run(self): train_path = self.datasets["train"] dev_wo_target_path = self.datasets["dev_wo_target"] @@ -45,24 +45,22 @@ class ASRunner: y_train = train_data[target_col] if eval_metric == "rmse": - automl = autosklearn.regression.AutoSklearnRegressor( + automl = self.autosklearn.regression.AutoSklearnRegressor( time_left_for_this_task=self.time_limit, per_run_time_limit=60, - metric=create_autosklearn_scorer(eval_metric), + metric=self.create_autosklearn_scorer(eval_metric), memory_limit=8192, - seed=self.seed, tmp_folder="AutosklearnModels/as-{}-{}".format( self.state["task"], datetime.now().strftime("%y%m%d_%H%M") ), n_jobs=-1, ) elif eval_metric in ["f1", "f1 weighted"]: - automl = autosklearn.classification.AutoSklearnClassifier( + automl = self.autosklearn.classification.AutoSklearnClassifier( time_left_for_this_task=self.time_limit, per_run_time_limit=60, - metric=create_autosklearn_scorer(eval_metric), + metric=self.create_autosklearn_scorer(eval_metric), memory_limit=8192, - seed=self.seed, tmp_folder="AutosklearnModels/as-{}-{}".format( self.state["task"], datetime.now().strftime("%y%m%d_%H%M") ),