From ba88bc09ba9d2e99478e946f302187792dcd3a69 Mon Sep 17 00:00:00 2001 From: abigailt Date: Wed, 21 Dec 2022 09:20:45 +0200 Subject: [PATCH] Add option for non-stratified split in minimizer Signed-off-by: abigailt --- apt/minimization/minimizer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/apt/minimization/minimizer.py b/apt/minimization/minimizer.py index 79983f7..6afdded 100644 --- a/apt/minimization/minimizer.py +++ b/apt/minimization/minimizer.py @@ -230,10 +230,17 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM if self.train_only_features_to_minimize: used_data = x_QI if self.is_regression: - X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4, random_state=14) + X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4, + random_state=14) else: - X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), stratify=dataset.get_labels(), test_size=0.4, - random_state=18) + try: + X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), + stratify=dataset.get_labels(), test_size=0.4, + random_state=14) + except ValueError: + print('Could not stratify split due to uncommon class value, doing unstratified split instead') + X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4, + random_state=14) X_train_QI = X_train.loc[:, self.features_to_minimize] X_test_QI = X_test.loc[:, self.features_to_minimize]