mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
fix import error delete seed
This commit is contained in:
parent
5d2de4d0ec
commit
c4fe056bca
1 changed files with 12 additions and 14 deletions
|
|
@ -9,15 +9,8 @@ def custom_scorer(y_true, y_pred, metric_name):
|
|||
return evaluate_score(y_pred, y_true, metric_name)
|
||||
|
||||
|
||||
def create_autosklearn_scorer(metric_name):
|
||||
return make_scorer(
|
||||
name=metric_name, score_func=partial(custom_scorer, metric_name=metric_name)
|
||||
)
|
||||
|
||||
|
||||
class ASRunner:
|
||||
time_limit = 300
|
||||
seed = 42
|
||||
|
||||
def __init__(self, state=None):
|
||||
self.state = state
|
||||
|
|
@ -25,12 +18,19 @@ class ASRunner:
|
|||
try:
|
||||
import autosklearn.classification
|
||||
import autosklearn.regression
|
||||
from autosklearn.metrics import make_scorer
|
||||
import autosklearn.metrics
|
||||
|
||||
self.autosklearn = autosklearn
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"autosklearn not found or system not supported, please check it first"
|
||||
)
|
||||
|
||||
def create_autosklearn_scorer(self, metric_name):
|
||||
return self.autosklearn.metrics.make_scorer(
|
||||
name=metric_name, score_func=partial(custom_scorer, metric_name=metric_name)
|
||||
)
|
||||
|
||||
def run(self):
|
||||
train_path = self.datasets["train"]
|
||||
dev_wo_target_path = self.datasets["dev_wo_target"]
|
||||
|
|
@ -45,24 +45,22 @@ class ASRunner:
|
|||
y_train = train_data[target_col]
|
||||
|
||||
if eval_metric == "rmse":
|
||||
automl = autosklearn.regression.AutoSklearnRegressor(
|
||||
automl = self.autosklearn.regression.AutoSklearnRegressor(
|
||||
time_left_for_this_task=self.time_limit,
|
||||
per_run_time_limit=60,
|
||||
metric=create_autosklearn_scorer(eval_metric),
|
||||
metric=self.create_autosklearn_scorer(eval_metric),
|
||||
memory_limit=8192,
|
||||
seed=self.seed,
|
||||
tmp_folder="AutosklearnModels/as-{}-{}".format(
|
||||
self.state["task"], datetime.now().strftime("%y%m%d_%H%M")
|
||||
),
|
||||
n_jobs=-1,
|
||||
)
|
||||
elif eval_metric in ["f1", "f1 weighted"]:
|
||||
automl = autosklearn.classification.AutoSklearnClassifier(
|
||||
automl = self.autosklearn.classification.AutoSklearnClassifier(
|
||||
time_left_for_this_task=self.time_limit,
|
||||
per_run_time_limit=60,
|
||||
metric=create_autosklearn_scorer(eval_metric),
|
||||
metric=self.create_autosklearn_scorer(eval_metric),
|
||||
memory_limit=8192,
|
||||
seed=self.seed,
|
||||
tmp_folder="AutosklearnModels/as-{}-{}".format(
|
||||
self.state["task"], datetime.now().strftime("%y%m%d_%H%M")
|
||||
),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue