diff --git a/apt/minimization/minimizer.py b/apt/minimization/minimizer.py index 79983f7..6afdded 100644 --- a/apt/minimization/minimizer.py +++ b/apt/minimization/minimizer.py @@ -230,10 +230,17 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM if self.train_only_features_to_minimize: used_data = x_QI if self.is_regression: - X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4, random_state=14) + X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4, + random_state=14) else: - X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), stratify=dataset.get_labels(), test_size=0.4, - random_state=18) + try: + X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), + stratify=dataset.get_labels(), test_size=0.4, + random_state=14) + except ValueError: + print('Could not stratify split due to uncommon class value, doing unstratified split instead') + X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4, + random_state=14) X_train_QI = X_train.loc[:, self.features_to_minimize] X_test_QI = X_test.loc[:, self.features_to_minimize]