mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-08 03:22:37 +02:00
Fix bug in pruning loop + fix test
This commit is contained in:
parent
186f11eaaf
commit
7055d5ecf6
2 changed files with 29 additions and 25 deletions
|
|
@@ -324,31 +324,34 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
||||||
print('Improving generalizations')
|
print('Improving generalizations')
|
||||||
level = 1
|
level = 1
|
||||||
while accuracy > self.target_accuracy:
|
while accuracy > self.target_accuracy:
|
||||||
try:
|
cells_previous_iter = self.cells
|
||||||
cells_previous_iter = self.cells
|
generalization_prev_iter = self._generalizations
|
||||||
generalization_prev_iter = self._generalizations
|
cells_by_id_prev = self._cells_by_id
|
||||||
cells_by_id_prev = self._cells_by_id
|
nodes = self._get_nodes_level(level)
|
||||||
nodes = self._get_nodes_level(level)
|
|
||||||
self._calculate_level_cells(level)
|
|
||||||
self._attach_cells_representatives(x_prepared, used_X_train, y_train, nodes)
|
|
||||||
|
|
||||||
self._calculate_generalizations()
|
try:
|
||||||
generalized = self._generalize(X_test, x_prepared_test, nodes, self.cells,
|
self._calculate_level_cells(level)
|
||||||
self._cells_by_id)
|
except TypeError as e:
|
||||||
accuracy = self.estimator.score(ArrayDataset(preprocessor.transform(generalized), y_test))
|
|
||||||
# if accuracy passed threshold roll back to previous iteration generalizations
|
|
||||||
if accuracy < self.target_accuracy:
|
|
||||||
self.cells = cells_previous_iter
|
|
||||||
self._generalizations = generalization_prev_iter
|
|
||||||
self._cells_by_id = cells_by_id_prev
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print('Pruned tree to level: %d, new relative accuracy: %f' % (level, accuracy))
|
|
||||||
level += 1
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
print(e)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
self._attach_cells_representatives(x_prepared, used_X_train, y_train, nodes)
|
||||||
|
|
||||||
|
self._calculate_generalizations()
|
||||||
|
generalized = self._generalize(X_test, x_prepared_test, nodes, self.cells,
|
||||||
|
self._cells_by_id)
|
||||||
|
accuracy = self.estimator.score(ArrayDataset(preprocessor.transform(generalized), y_test))
|
||||||
|
# if accuracy passed threshold roll back to previous iteration generalizations
|
||||||
|
if accuracy < self.target_accuracy:
|
||||||
|
self.cells = cells_previous_iter
|
||||||
|
self._generalizations = generalization_prev_iter
|
||||||
|
self._cells_by_id = cells_by_id_prev
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print('Pruned tree to level: %d, new relative accuracy: %f' % (level, accuracy))
|
||||||
|
level += 1
|
||||||
|
|
||||||
|
|
||||||
# if accuracy below threshold, improve accuracy by removing features from generalization
|
# if accuracy below threshold, improve accuracy by removing features from generalization
|
||||||
elif accuracy < self.target_accuracy:
|
elif accuracy < self.target_accuracy:
|
||||||
print('Improving accuracy')
|
print('Improving accuracy')
|
||||||
|
|
@@ -569,7 +572,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
||||||
features = self._categorical_data.columns
|
features = self._categorical_data.columns
|
||||||
for cell in self.cells:
|
for cell in self.cells:
|
||||||
new_cell = {'id': cell['id'], 'label': cell['label'], 'ranges': {}, 'categories': {}, 'hist': cell['hist'],
|
new_cell = {'id': cell['id'], 'label': cell['label'], 'ranges': {}, 'categories': {}, 'hist': cell['hist'],
|
||||||
'representative': None}
|
'untouched': [], 'representative': None}
|
||||||
for feature in features:
|
for feature in features:
|
||||||
if feature in self._one_hot_vector_features_to_features.keys():
|
if feature in self._one_hot_vector_features_to_features.keys():
|
||||||
# feature is categorical and should be mapped
|
# feature is categorical and should be mapped
|
||||||
|
|
|
||||||
|
|
@@ -1,6 +1,8 @@
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from numpy.testing import assert_almost_equal
|
||||||
|
|
||||||
from sklearn.compose import ColumnTransformer
|
from sklearn.compose import ColumnTransformer
|
||||||
|
|
||||||
from sklearn.datasets import load_boston, load_diabetes
|
from sklearn.datasets import load_boston, load_diabetes
|
||||||
|
|
@@ -912,13 +914,12 @@ def test_blackbox_model():
|
||||||
gen.fit(dataset=train_dataset)
|
gen.fit(dataset=train_dataset)
|
||||||
transformed = gen.transform(dataset=ad)
|
transformed = gen.transform(dataset=ad)
|
||||||
gener = gen.generalizations
|
gener = gen.generalizations
|
||||||
expected_generalizations = {'ranges': {'0': [], '1': [], '2': [4.849999904632568, 5.049999952316284],
|
expected_generalizations = {'ranges': {'0': [], '1': [], '2': [4.849999904632568], '3': [0.7000000029802322]},
|
||||||
'3': [0.7000000029802322, 1.600000023841858]},
|
|
||||||
'categories': {},
|
'categories': {},
|
||||||
'untouched': []}
|
'untouched': []}
|
||||||
|
|
||||||
for key in expected_generalizations['ranges']:
|
for key in expected_generalizations['ranges']:
|
||||||
assert (set(expected_generalizations['ranges'][key]) == set(gener['ranges'][key]))
|
assert_almost_equal(expected_generalizations['ranges'][key], gener['ranges'][key])
|
||||||
for key in expected_generalizations['categories']:
|
for key in expected_generalizations['categories']:
|
||||||
assert (set([frozenset(sl) for sl in expected_generalizations['categories'][key]]) ==
|
assert (set([frozenset(sl) for sl in expected_generalizations['categories'][key]]) ==
|
||||||
set([frozenset(sl) for sl in gener['categories'][key]]))
|
set([frozenset(sl) for sl in gener['categories'][key]]))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue