mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-03 16:52:38 +02:00
Support pytorch models in data minimization (#85)
* Support pytorch models in data minimization Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
a40484e0c9
commit
26addd192f
2 changed files with 76 additions and 3 deletions
|
|
@ -256,6 +256,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
# Going to fit
|
||||
# (currently not dealing with option to fit with only X and y and no estimator)
|
||||
if self.estimator and dataset and dataset.get_samples() is not None and dataset.get_labels() is not None:
|
||||
dtype = dataset.get_samples().dtype
|
||||
x = pd.DataFrame(dataset.get_samples(), columns=self._features)
|
||||
if not self.features_to_minimize:
|
||||
self.features_to_minimize = self._features
|
||||
|
|
@ -340,7 +341,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
|
||||
|
||||
# check accuracy
|
||||
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
|
||||
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized).astype(dtype), y_test))
|
||||
print('Initial accuracy of model on generalized data, relative to original model predictions '
|
||||
'(base generalization derived from tree, before improvements): %f' % accuracy)
|
||||
|
||||
|
|
@ -370,7 +371,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
else:
|
||||
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
|
||||
|
||||
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
|
||||
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized).astype(dtype),
|
||||
y_test))
|
||||
# if accuracy passed threshold roll back to previous iteration generalizations
|
||||
if accuracy < self.target_accuracy:
|
||||
self.cells = cells_previous_iter
|
||||
|
|
@ -399,7 +401,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
|||
self._cells_by_id)
|
||||
else:
|
||||
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
|
||||
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
|
||||
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized).astype(dtype),
|
||||
y_test))
|
||||
print('Removed feature: %s, new relative accuracy: %f' % (removed_feature, accuracy))
|
||||
|
||||
# self._cells currently holds the chosen generalization based on target accuracy
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue