mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-24 20:36:21 +02:00
Add option for non-stratified split in minimizer
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
89bdcfc00e
commit
ba88bc09ba
1 changed files with 10 additions and 3 deletions
|
|
@ -230,10 +230,17 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
if self.train_only_features_to_minimize:
|
||||
used_data = x_QI
|
||||
if self.is_regression:
|
||||
X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4, random_state=14)
|
||||
X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4,
|
||||
random_state=14)
|
||||
else:
|
||||
X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), stratify=dataset.get_labels(), test_size=0.4,
|
||||
random_state=18)
|
||||
try:
|
||||
X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(),
|
||||
stratify=dataset.get_labels(), test_size=0.4,
|
||||
random_state=14)
|
||||
except ValueError:
|
||||
print('Could not stratify split due to uncommon class value, doing unstratified split instead')
|
||||
X_train, X_test, y_train, y_test = train_test_split(x, dataset.get_labels(), test_size=0.4,
|
||||
random_state=14)
|
||||
|
||||
X_train_QI = X_train.loc[:, self.features_to_minimize]
|
||||
X_test_QI = X_test.loc[:, self.features_to_minimize]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue