mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-08 15:05:13 +02:00
Extract common code to methods
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
570c6f8966
commit
ea8564bc4b
1 changed files with 39 additions and 61 deletions
|
|
@ -315,17 +315,10 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
self._attach_cells_representatives(x_prepared, used_x_train, y_train, nodes)
|
||||
|
||||
# self._cells currently holds the generalization created from the tree leaves
|
||||
self._calculate_generalizations(x_test)
|
||||
if self.generalize_using_transform:
|
||||
generalized = self._generalize_from_tree(x_test, x_prepared_test, nodes, self.cells, self._cells_by_id)
|
||||
else:
|
||||
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
|
||||
generalized = self._generalize(x_test, x_prepared_test, nodes)
|
||||
|
||||
# check accuracy
|
||||
if self.encoder:
|
||||
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
|
||||
else:
|
||||
accuracy = self.estimator.score(ArrayDataset(generalized, y_test))
|
||||
accuracy = self._calculate_accuracy(generalized, y_test, self.estimator, self.encoder)
|
||||
print('Initial accuracy of model on generalized data, relative to original model predictions '
|
||||
'(base generalization derived from tree, before improvements): %f' % accuracy)
|
||||
|
||||
|
|
@ -348,17 +341,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
|
||||
self._attach_cells_representatives(x_prepared, used_x_train, y_train, nodes)
|
||||
|
||||
self._calculate_generalizations(x_test)
|
||||
if self.generalize_using_transform:
|
||||
generalized = self._generalize_from_tree(x_test, x_prepared_test, nodes, self.cells,
|
||||
self._cells_by_id)
|
||||
else:
|
||||
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
|
||||
|
||||
if self.encoder:
|
||||
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
|
||||
else:
|
||||
accuracy = self.estimator.score(ArrayDataset(generalized, y_test))
|
||||
generalized = self._generalize(x_test, x_prepared_test, nodes)
|
||||
accuracy = self._calculate_accuracy(generalized, y_test, self.estimator, self.encoder)
|
||||
# if accuracy passed threshold roll back to previous iteration generalizations
|
||||
if accuracy < self.target_accuracy:
|
||||
self.cells = cells_previous_iter
|
||||
|
|
@ -381,17 +365,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
if removed_feature is None:
|
||||
break
|
||||
|
||||
self._calculate_generalizations(x_test)
|
||||
if self.generalize_using_transform:
|
||||
generalized = self._generalize_from_tree(x_test, x_prepared_test, nodes, self.cells,
|
||||
self._cells_by_id)
|
||||
else:
|
||||
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
|
||||
|
||||
if self.encoder:
|
||||
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
|
||||
else:
|
||||
accuracy = self.estimator.score(ArrayDataset(generalized, y_test))
|
||||
generalized = self._generalize(x_test, x_prepared_test, nodes)
|
||||
accuracy = self._calculate_accuracy(generalized, y_test, self.estimator, self.encoder)
|
||||
print('Removed feature: %s, new relative accuracy: %f' % (removed_feature, accuracy))
|
||||
|
||||
# self._cells currently holds the chosen generalization based on target accuracy
|
||||
|
|
@ -918,6 +893,15 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
|
||||
return original_data_generalized
|
||||
|
||||
def _generalize(self, data, data_prepared, nodes):
|
||||
self._calculate_generalizations(data)
|
||||
if self.generalize_using_transform:
|
||||
generalized = self._generalize_from_tree(data, data_prepared, nodes, self.cells,
|
||||
self._cells_by_id)
|
||||
else:
|
||||
generalized = self._generalize_from_generalizations(data, self.generalizations)
|
||||
return generalized
|
||||
|
||||
@staticmethod
|
||||
def _map_to_ranges_categories(samples, ranges, categories):
|
||||
all_sample_indexes = []
|
||||
|
|
@ -987,21 +971,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
feature_data[feature],
|
||||
total)
|
||||
if feature_ncp > 0:
|
||||
# divide by accuracy gain
|
||||
new_cells = copy.deepcopy(self.cells)
|
||||
cells_by_id = copy.deepcopy(self._cells_by_id)
|
||||
GeneralizeToRepresentative._remove_feature_from_cells(new_cells, cells_by_id, feature)
|
||||
generalized = self._generalize_from_tree(original_data, prepared_data, nodes, new_cells,
|
||||
cells_by_id)
|
||||
if self.encoder:
|
||||
accuracy_gain = self.estimator.score(ArrayDataset(self.encoder.transform(generalized),
|
||||
labels)) - current_accuracy
|
||||
else:
|
||||
accuracy_gain = self.estimator.score(ArrayDataset(generalized, labels)) - current_accuracy
|
||||
if accuracy_gain < 0:
|
||||
accuracy_gain = 0
|
||||
if accuracy_gain != 0:
|
||||
feature_ncp = feature_ncp / accuracy_gain
|
||||
feature_ncp = self._normalize_ncp_by_accuracy_gain(original_data, prepared_data, nodes, feature,
|
||||
feature_ncp, labels, current_accuracy)
|
||||
|
||||
if feature_ncp < range_min:
|
||||
range_min = feature_ncp
|
||||
|
|
@ -1017,22 +988,9 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
feature_data[feature],
|
||||
total)
|
||||
if feature_ncp > 0:
|
||||
# divide by accuracy loss
|
||||
new_cells = copy.deepcopy(self.cells)
|
||||
cells_by_id = copy.deepcopy(self._cells_by_id)
|
||||
GeneralizeToRepresentative._remove_feature_from_cells(new_cells, cells_by_id, feature)
|
||||
generalized = self._generalize_from_tree(original_data, prepared_data, nodes, new_cells,
|
||||
cells_by_id)
|
||||
if self.encoder:
|
||||
accuracy_gain = self.estimator.score(ArrayDataset(self.encoder.transform(generalized),
|
||||
labels)) - current_accuracy
|
||||
else:
|
||||
accuracy_gain = self.estimator.score(ArrayDataset(generalized, labels)) - current_accuracy
|
||||
feature_ncp = self._normalize_ncp_by_accuracy_gain(original_data, prepared_data, nodes, feature,
|
||||
feature_ncp, labels, current_accuracy)
|
||||
|
||||
if accuracy_gain < 0:
|
||||
accuracy_gain = 0
|
||||
if accuracy_gain != 0:
|
||||
feature_ncp = feature_ncp / accuracy_gain
|
||||
if feature_ncp < range_min:
|
||||
range_min = feature_ncp
|
||||
remove_feature = feature
|
||||
|
|
@ -1062,6 +1020,21 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
feature_ncp += cell_ncp
|
||||
return feature_ncp
|
||||
|
||||
def _normalize_ncp_by_accuracy_gain(self, original_data, prepared_data, nodes, feature, feature_ncp, labels,
|
||||
current_accuracy):
|
||||
new_cells = copy.deepcopy(self.cells)
|
||||
cells_by_id = copy.deepcopy(self._cells_by_id)
|
||||
GeneralizeToRepresentative._remove_feature_from_cells(new_cells, cells_by_id, feature)
|
||||
generalized = self._generalize_from_tree(original_data, prepared_data, nodes, new_cells,
|
||||
cells_by_id)
|
||||
accuracy = self._calculate_accuracy(generalized, labels, self.estimator, self.encoder)
|
||||
accuracy_gain = accuracy - current_accuracy
|
||||
if accuracy_gain < 0:
|
||||
accuracy_gain = 0
|
||||
if accuracy_gain != 0:
|
||||
feature_ncp = feature_ncp / accuracy_gain
|
||||
return feature_ncp
|
||||
|
||||
def _calculate_generalizations(self, samples: Optional[pd.DataFrame] = None):
|
||||
ranges, range_representatives = self._calculate_ranges(self.cells)
|
||||
categories, category_representatives = self._calculate_categories(self.cells)
|
||||
|
|
@ -1281,3 +1254,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
|
||||
for feature in to_remove:
|
||||
del generalizations['categories'][feature]
|
||||
|
||||
@staticmethod
|
||||
def _calculate_accuracy(generalized, y_test, estimator, encoder):
|
||||
generalized_data = encoder.transform(generalized) if encoder else generalized
|
||||
return estimator.score(ArrayDataset(generalized_data, y_test))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue