mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-02 14:45:13 +02:00
generalize_using_transform=False supported
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
710aae4083
commit
4541ee60a2
2 changed files with 32 additions and 4 deletions
|
|
@ -93,8 +93,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
|
||||||
self.train_only_features_to_minimize = train_only_features_to_minimize
|
self.train_only_features_to_minimize = train_only_features_to_minimize
|
||||||
self.is_regression = is_regression
|
self.is_regression = is_regression
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
# self.generalize_using_transform = generalize_using_transform
|
self.generalize_using_transform = generalize_using_transform
|
||||||
self.generalize_using_transform = False
|
|
||||||
self._ncp = 0.0
|
self._ncp = 0.0
|
||||||
self._feature_data = {}
|
self._feature_data = {}
|
||||||
self._categorical_values = {}
|
self._categorical_values = {}
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,7 @@ def test_minimizer_params_not_transform(data):
|
||||||
model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES)
|
model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES)
|
||||||
model.fit(ArrayDataset(X, y))
|
model.fit(ArrayDataset(X, y))
|
||||||
|
|
||||||
gen = GeneralizeToRepresentative(model, cells=cells)
|
gen = GeneralizeToRepresentative(model, cells=cells, generalize_using_transform=False)
|
||||||
gen.calculate_ncp(X)
|
gen.calculate_ncp(X)
|
||||||
ncp = gen.ncp
|
ncp = gen.ncp
|
||||||
assert (ncp > 0.0)
|
assert (ncp > 0.0)
|
||||||
|
|
@ -158,7 +158,7 @@ def test_minimizer_fit_not_transform(data):
|
||||||
if predictions.shape[1] > 1:
|
if predictions.shape[1] > 1:
|
||||||
predictions = np.argmax(predictions, axis=1)
|
predictions = np.argmax(predictions, axis=1)
|
||||||
target_accuracy = 0.5
|
target_accuracy = 0.5
|
||||||
gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy)
|
gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy, generalize_using_transform=False)
|
||||||
train_dataset = ArrayDataset(X, predictions, features_names=features)
|
train_dataset = ArrayDataset(X, predictions, features_names=features)
|
||||||
|
|
||||||
gen.fit(dataset=train_dataset)
|
gen.fit(dataset=train_dataset)
|
||||||
|
|
@ -1043,3 +1043,32 @@ def test_untouched():
|
||||||
assert (set([frozenset(sl) for sl in expected_generalizations['categories'][key]])
|
assert (set([frozenset(sl) for sl in expected_generalizations['categories'][key]])
|
||||||
== set([frozenset(sl) for sl in gener['categories'][key]]))
|
== set([frozenset(sl) for sl in gener['categories'][key]]))
|
||||||
assert (set(expected_generalizations['untouched']) == set(gener['untouched']))
|
assert (set(expected_generalizations['untouched']) == set(gener['untouched']))
|
||||||
|
|
||||||
|
|
||||||
|
def test_errors():
|
||||||
|
features = ['age', 'height']
|
||||||
|
X = np.array([[23, 165],
|
||||||
|
[45, 158],
|
||||||
|
[56, 123],
|
||||||
|
[67, 154],
|
||||||
|
[45, 149],
|
||||||
|
[42, 166],
|
||||||
|
[73, 172],
|
||||||
|
[94, 168],
|
||||||
|
[69, 175],
|
||||||
|
[24, 181],
|
||||||
|
[18, 190]])
|
||||||
|
y = np.array([1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0])
|
||||||
|
base_est = DecisionTreeClassifier(random_state=0, min_samples_split=2,
|
||||||
|
min_samples_leaf=1)
|
||||||
|
model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES)
|
||||||
|
model.fit(ArrayDataset(X, y))
|
||||||
|
ad = ArrayDataset(X)
|
||||||
|
predictions = model.predict(ad)
|
||||||
|
if predictions.shape[1] > 1:
|
||||||
|
predictions = np.argmax(predictions, axis=1)
|
||||||
|
gen = GeneralizeToRepresentative(model, generalize_using_transform=False)
|
||||||
|
train_dataset = ArrayDataset(X, predictions, features_names=features)
|
||||||
|
gen.fit(dataset=train_dataset)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
gen.transform(X)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue