Add option for non-stratified split in minimizer

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2022-12-21 09:20:45 +02:00 committed by abigailgold
parent 89bdcfc00e
commit ba88bc09ba

View file

@@ -230,10 +230,17 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
if self.train_only_features_to_minimize:
used_data = x_QI
if self.is_regression:
X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4, random_state=14)
X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4,
random_state=14)
else:
X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), stratify=dataset.get_labels(), test_size=0.4,
random_state=18)
try:
X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(),
stratify=dataset.get_labels(), test_size=0.4,
random_state=14)
except ValueError:
print('Could not stratify split due to uncommon class value, doing unstratified split instead')
X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4,
random_state=14)
X_train_QI = X_train.loc[:, self.features_to_minimize]
X_test_QI = X_test.loc[:, self.features_to_minimize]