mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-08 15:05:13 +02:00
Review comments
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
a3d294af2d
commit
686969eb86
1 changed file with 12 additions and 2 deletions
|
|
@@ -102,7 +102,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
self.features_to_minimize = features_to_minimize
|
||||
self.feature_slices = feature_slices
|
||||
if self.feature_slices:
|
||||
self.all_one_hot_features = set([str(feature) for encoded in self.feature_slices for feature in encoded])
|
||||
self.all_one_hot_features = {str(feature) for encoded in self.feature_slices for feature in encoded}
|
||||
else:
|
||||
self.all_one_hot_features = set()
|
||||
self.train_only_features_to_minimize = train_only_features_to_minimize
|
||||
|
|
@@ -398,6 +398,14 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
self._ncp_scores.generalizations_score = self.calculate_ncp(x_test_dataset)
|
||||
else:
|
||||
print('No fitting was performed as some information was missing')
|
||||
if not self.estimator:
|
||||
print('No estimator provided')
|
||||
elif not dataset:
|
||||
print('No data provided')
|
||||
elif dataset.get_samples() is None:
|
||||
print('No samples provided')
|
||||
elif dataset.get_labels() is None:
|
||||
print('No labels provided')
|
||||
|
||||
# Return the transformer
|
||||
return self
|
||||
|
|
@@ -736,6 +744,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
feature_value = 0
|
||||
elif range['end'] is None and range['start'] > 0:
|
||||
feature_value = 1
|
||||
else:
|
||||
raise ValueError('Illegal range for 1-hot encoded feature')
|
||||
new_cell['categories'][feature] = [feature_value]
|
||||
|
||||
# need to add other columns that represent same 1-hot encoded feature
|
||||
|
|
@@ -1279,7 +1289,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
def _get_other_features_in_encoding(feature, feature_slices):
|
||||
for encoded in feature_slices:
|
||||
if feature in encoded:
|
||||
return (list(set(encoded) - set([feature]))), encoded
|
||||
return (list(set(encoded) - {feature})), encoded
|
||||
return [], []
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue