diff --git a/apt/minimization/minimizer.py b/apt/minimization/minimizer.py index 702b799..1e15d3d 100644 --- a/apt/minimization/minimizer.py +++ b/apt/minimization/minimizer.py @@ -93,8 +93,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM self.train_only_features_to_minimize = train_only_features_to_minimize self.is_regression = is_regression self.encoder = encoder - # self.generalize_using_transform = generalize_using_transform - self.generalize_using_transform = False + self.generalize_using_transform = generalize_using_transform self._ncp = 0.0 self._feature_data = {} self._categorical_values = {} diff --git a/tests/test_minimizer.py b/tests/test_minimizer.py index 6cc5197..8e5a6cc 100644 --- a/tests/test_minimizer.py +++ b/tests/test_minimizer.py @@ -75,7 +75,7 @@ def test_minimizer_params_not_transform(data): model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES) model.fit(ArrayDataset(X, y)) - gen = GeneralizeToRepresentative(model, cells=cells) + gen = GeneralizeToRepresentative(model, cells=cells, generalize_using_transform=False) gen.calculate_ncp(X) ncp = gen.ncp assert (ncp > 0.0) @@ -158,7 +158,7 @@ def test_minimizer_fit_not_transform(data): if predictions.shape[1] > 1: predictions = np.argmax(predictions, axis=1) target_accuracy = 0.5 - gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy) + gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy, generalize_using_transform=False) train_dataset = ArrayDataset(X, predictions, features_names=features) gen.fit(dataset=train_dataset) @@ -1043,3 +1043,32 @@ def test_untouched(): assert (set([frozenset(sl) for sl in expected_generalizations['categories'][key]]) == set([frozenset(sl) for sl in gener['categories'][key]])) assert (set(expected_generalizations['untouched']) == set(gener['untouched'])) + + +def test_errors(): + features = ['age', 'height'] + X = np.array([[23, 165], + [45, 158], + [56, 123], + [67, 154], + [45, 149], + [42, 166], + [73, 172], + [94, 168], + [69, 175], + [24, 181], + [18, 190]]) + y = np.array([1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0]) + base_est = DecisionTreeClassifier(random_state=0, min_samples_split=2, + min_samples_leaf=1) + model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES) + model.fit(ArrayDataset(X, y)) + ad = ArrayDataset(X) + predictions = model.predict(ad) + if predictions.shape[1] > 1: + predictions = np.argmax(predictions, axis=1) + gen = GeneralizeToRepresentative(model, generalize_using_transform=False) + train_dataset = ArrayDataset(X, predictions, features_names=features) + gen.fit(dataset=train_dataset) + with pytest.raises(ValueError): + gen.transform(X)