diff --git a/apt/minimization/minimizer.py b/apt/minimization/minimizer.py index 6afdded..9becaf0 100644 --- a/apt/minimization/minimizer.py +++ b/apt/minimization/minimizer.py @@ -236,11 +236,11 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM try: X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), stratify=dataset.get_labels(), test_size=0.4, - random_state=14) + random_state=18) except ValueError: print('Could not stratify split due to uncommon class value, doing unstratified split instead') X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4, - random_state=14) + random_state=18) X_train_QI = X_train.loc[:, self.features_to_minimize] X_test_QI = X_test.loc[:, self.features_to_minimize]