Support pytorch models in data minimization (#85)

* Support pytorch models in data minimization

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailgold 2023-09-21 17:48:15 +03:00 committed by GitHub
parent a40484e0c9
commit 26addd192f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 76 additions and 3 deletions

View file

@ -256,6 +256,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# Going to fit
# (fitting with only X and y, without an estimator, is not currently supported)
if self.estimator and dataset and dataset.get_samples() is not None and dataset.get_labels() is not None:
dtype = dataset.get_samples().dtype
x = pd.DataFrame(dataset.get_samples(), columns=self._features)
if not self.features_to_minimize:
self.features_to_minimize = self._features
@ -340,7 +341,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
# check accuracy
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized).astype(dtype), y_test))
print('Initial accuracy of model on generalized data, relative to original model predictions '
'(base generalization derived from tree, before improvements): %f' % accuracy)
@ -370,7 +371,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
else:
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized).astype(dtype),
y_test))
# if accuracy dropped below the target threshold, roll back to the previous iteration's generalizations
if accuracy < self.target_accuracy:
self.cells = cells_previous_iter
@ -399,7 +401,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self._cells_by_id)
else:
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized).astype(dtype),
y_test))
print('Removed feature: %s, new relative accuracy: %f' % (removed_feature, accuracy))
# self._cells currently holds the chosen generalization based on target accuracy