diff --git a/apt/minimization/minimizer.py b/apt/minimization/minimizer.py index 53ef7c5..acf4fd8 100644 --- a/apt/minimization/minimizer.py +++ b/apt/minimization/minimizer.py @@ -346,8 +346,9 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM # if accuracy above threshold, improve generalization if accuracy > self.target_accuracy: print('Improving generalizations') - self._level = 1 + self._level = 0 while accuracy > self.target_accuracy: + self._level += 1 cells_previous_iter = self.cells generalization_prev_iter = self._generalizations cells_by_id_prev = self._cells_by_id @@ -373,7 +374,6 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM break else: print('Pruned tree to level: %d, new relative accuracy: %f' % (self._level, accuracy)) - self._level += 1 # if accuracy below threshold, improve accuracy by removing features from generalization elif accuracy < self.target_accuracy: diff --git a/tests/test_minimizer.py b/tests/test_minimizer.py index 9c8fc28..abff834 100644 --- a/tests/test_minimizer.py +++ b/tests/test_minimizer.py @@ -181,8 +181,8 @@ def compare_generalizations(gener, expected_generalizations): == set(gener['range_representatives'][key])) if 'category_representatives' in expected_generalizations: for key in expected_generalizations['category_representatives']: - assert (set([frozenset(sl) for sl in expected_generalizations['category_representatives'][key]]) - == set([frozenset(sl) for sl in gener['category_representatives'][key]])) + assert (set(expected_generalizations['category_representatives'][key]) + == set(gener['category_representatives'][key])) def check_features(features, expected_generalizations, transformed, x, pandas=False): @@ -961,6 +961,47 @@ def test_minimizer_ndarray_one_hot(): assert ((rel_accuracy >= target_accuracy) or (target_accuracy - rel_accuracy) <= ACCURACY_DIFF) +def test_minimizer_ndarray_one_hot_gen(): + x_train = np.array([[23, 0, 1, 165], + [45, 0, 1, 158], + [56, 1, 0, 123], + [67, 0, 1, 154], + [45, 1, 0, 149], + [42, 1, 0, 166], + [73, 0, 1, 172], + [94, 0, 1, 168], + [69, 0, 1, 175], + [24, 1, 0, 181], + [18, 1, 0, 190]]) + y_train = np.array([1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0]) + + model = DecisionTreeClassifier() + model.fit(x_train, y_train) + predictions = model.predict(x_train) + + features = ['0', '1', '2', '3'] + QI = [0, 1, 2] + QI_slices = [[1, 2]] + target_accuracy = 0.2 + gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy, feature_slices=QI_slices, + features_to_minimize=QI) + gen.fit(dataset=ArrayDataset(x_train, predictions)) + transformed = gen.transform(dataset=ArrayDataset(x_train)) + gener = gen.generalizations + expected_generalizations = {'categories': {'1': [[0, 1]], '2': [[0, 1]]}, + 'category_representatives': {'1': [0], '2': [1]}, + 'range_representatives': {'0': []}, 'ranges': {'0': []}, 'untouched': ['3']} + + compare_generalizations(gener, expected_generalizations) + + check_features(features, expected_generalizations, transformed, x_train) + ncp = gen.ncp.transform_score + check_ncp(ncp, expected_generalizations) + + rel_accuracy = model.score(transformed, predictions) + assert ((rel_accuracy >= target_accuracy) or (target_accuracy - rel_accuracy) <= ACCURACY_DIFF) + + def test_anonymize_pandas_one_hot(): features = ["age", "gender_M", "gender_F", "height"] x_train = np.array([[23, 0, 1, 165],