mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-01 07:46:21 +02:00
Fix misclassification of categorical features with no generalizations (they now appear under the 'untouched' category)
This commit is contained in:
parent
fe676fa426
commit
186f11eaaf
2 changed files with 40 additions and 3 deletions
|
|
@ -62,9 +62,9 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
cells: Optional[list] = None, categorical_features: Optional[Union[np.ndarray, list]] = None,
|
||||
features_to_minimize: Optional[Union[np.ndarray, list]] = None, train_only_QI: Optional[bool] = True,
|
||||
is_regression: Optional[bool] = False):
|
||||
if issubclass(estimator.__class__, Model):
|
||||
self.estimator = estimator
|
||||
else:
|
||||
|
||||
self.estimator = estimator
|
||||
if estimator is not None and not issubclass(estimator.__class__, Model):
|
||||
if is_regression:
|
||||
self.estimator = SklearnRegressor(estimator)
|
||||
else:
|
||||
|
|
@ -832,6 +832,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
self._generalizations = {'ranges': GeneralizeToRepresentative._calculate_ranges(self.cells),
|
||||
'categories': GeneralizeToRepresentative._calculate_categories(self.cells),
|
||||
'untouched': GeneralizeToRepresentative._calculate_untouched(self.cells)}
|
||||
self._remove_categorical_untouched(self._generalizations)
|
||||
|
||||
def _find_range_count(self, samples, ranges):
|
||||
samples_df = pd.DataFrame(samples, columns=self._categorical_data.columns)
|
||||
|
|
@ -988,3 +989,17 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
del cell['categories'][feature]
|
||||
cell['untouched'].append(feature)
|
||||
cells_by_id[cell['id']] = cell.copy()
|
||||
|
||||
@staticmethod
def _remove_categorical_untouched(generalizations):
    """
    Move categorical features that were not actually generalized to the 'untouched' list.

    A feature listed under ``generalizations['categories']`` is kept there only if at least one
    of its category groups contains more than one value. Features whose groups all hold a single
    value (or that have no groups at all) are appended to ``generalizations['untouched']`` and
    deleted from ``generalizations['categories']``. The dictionary is modified in place.

    :param generalizations: dict with a 'categories' entry mapping each feature name to a list of
                            category groups, and optionally an 'untouched' list of feature names
                            (created on demand if a feature has to be moved and it is missing).
    :return: None
    """
    categories = generalizations['categories']

    # Collect the names first: entries cannot be deleted from the dict while iterating over it.
    not_generalized = [
        feature for feature, groups in categories.items()
        if all(len(group) <= 1 for group in groups)
    ]

    for feature in not_generalized:
        untouched = generalizations.setdefault('untouched', [])
        # Avoid listing the same feature twice if it is already marked as untouched.
        if feature not in untouched:
            untouched.append(feature)
        del categories[feature]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue