diff --git a/apt/minimization/minimizer.py b/apt/minimization/minimizer.py index ebac318..c975fdb 100644 --- a/apt/minimization/minimizer.py +++ b/apt/minimization/minimizer.py @@ -324,31 +324,34 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM print('Improving generalizations') level = 1 while accuracy > self.target_accuracy: - try: - cells_previous_iter = self.cells - generalization_prev_iter = self._generalizations - cells_by_id_prev = self._cells_by_id - nodes = self._get_nodes_level(level) - self._calculate_level_cells(level) - self._attach_cells_representatives(x_prepared, used_X_train, y_train, nodes) + cells_previous_iter = self.cells + generalization_prev_iter = self._generalizations + cells_by_id_prev = self._cells_by_id + nodes = self._get_nodes_level(level) - self._calculate_generalizations() - generalized = self._generalize(X_test, x_prepared_test, nodes, self.cells, - self._cells_by_id) - accuracy = self.estimator.score(ArrayDataset(preprocessor.transform(generalized), y_test)) - # if accuracy passed threshold roll back to previous iteration generalizations - if accuracy < self.target_accuracy: - self.cells = cells_previous_iter - self._generalizations = generalization_prev_iter - self._cells_by_id = cells_by_id_prev - break - else: - print('Pruned tree to level: %d, new relative accuracy: %f' % (level, accuracy)) - level += 1 - except Exception as e: + try: + self._calculate_level_cells(level) + except TypeError as e: print(e) break + self._attach_cells_representatives(x_prepared, used_X_train, y_train, nodes) + + self._calculate_generalizations() + generalized = self._generalize(X_test, x_prepared_test, nodes, self.cells, + self._cells_by_id) + accuracy = self.estimator.score(ArrayDataset(preprocessor.transform(generalized), y_test)) + # if accuracy passed threshold roll back to previous iteration generalizations + if accuracy < self.target_accuracy: + self.cells = cells_previous_iter + self._generalizations = generalization_prev_iter + self._cells_by_id = cells_by_id_prev + break + else: + print('Pruned tree to level: %d, new relative accuracy: %f' % (level, accuracy)) + level += 1 + + # if accuracy below threshold, improve accuracy by removing features from generalization elif accuracy < self.target_accuracy: print('Improving accuracy') @@ -569,7 +572,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM features = self._categorical_data.columns for cell in self.cells: new_cell = {'id': cell['id'], 'label': cell['label'], 'ranges': {}, 'categories': {}, 'hist': cell['hist'], - 'representative': None} + 'untouched': [], 'representative': None} for feature in features: if feature in self._one_hot_vector_features_to_features.keys(): # feature is categorical and should be mapped diff --git a/tests/test_minimizer.py b/tests/test_minimizer.py index 6adad90..dca3ebd 100644 --- a/tests/test_minimizer.py +++ b/tests/test_minimizer.py @@ -1,6 +1,8 @@ import pytest import numpy as np import pandas as pd +from numpy.testing import assert_almost_equal + from sklearn.compose import ColumnTransformer from sklearn.datasets import load_boston, load_diabetes @@ -912,13 +914,12 @@ def test_blackbox_model(): gen.fit(dataset=train_dataset) transformed = gen.transform(dataset=ad) gener = gen.generalizations - expected_generalizations = {'ranges': {'0': [], '1': [], '2': [4.849999904632568, 5.049999952316284], - '3': [0.7000000029802322, 1.600000023841858]}, + expected_generalizations = {'ranges': {'0': [], '1': [], '2': [4.849999904632568], '3': [0.7000000029802322]}, 'categories': {}, 'untouched': []} for key in expected_generalizations['ranges']: - assert (set(expected_generalizations['ranges'][key]) == set(gener['ranges'][key])) + assert_almost_equal(expected_generalizations['ranges'][key], gener['ranges'][key]) for key in expected_generalizations['categories']: assert (set([frozenset(sl) for sl in expected_generalizations['categories'][key]]) == set([frozenset(sl) for sl in gener['categories'][key]]))