Review comments

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2023-12-24 17:17:59 -05:00
parent a3d294af2d
commit 686969eb86

View file

@ -102,7 +102,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self.features_to_minimize = features_to_minimize
self.feature_slices = feature_slices
if self.feature_slices:
self.all_one_hot_features = set([str(feature) for encoded in self.feature_slices for feature in encoded])
self.all_one_hot_features = {str(feature) for encoded in self.feature_slices for feature in encoded}
else:
self.all_one_hot_features = set()
self.train_only_features_to_minimize = train_only_features_to_minimize
@ -398,6 +398,14 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self._ncp_scores.generalizations_score = self.calculate_ncp(x_test_dataset)
else:
print('No fitting was performed as some information was missing')
if not self.estimator:
print('No estimator provided')
elif not dataset:
print('No data provided')
elif dataset.get_samples() is None:
print('No samples provided')
elif dataset.get_labels() is None:
print('No labels provided')
# Return the transformer
return self
@ -736,6 +744,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
feature_value = 0
elif range['end'] is None and range['start'] > 0:
feature_value = 1
else:
raise ValueError('Illegal range for 1-hot encoded feature')
new_cell['categories'][feature] = [feature_value]
# need to add other columns that represent same 1-hot encoded feature
@ -1279,7 +1289,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
def _get_other_features_in_encoding(feature, feature_slices):
for encoded in feature_slices:
if feature in encoded:
return (list(set(encoded) - set([feature]))), encoded
return (list(set(encoded) - {feature})), encoded
return [], []
@staticmethod