diff --git a/apt/minimization/minimizer.py b/apt/minimization/minimizer.py index df70a3a..ebac318 100644 --- a/apt/minimization/minimizer.py +++ b/apt/minimization/minimizer.py @@ -62,9 +62,9 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM cells: Optional[list] = None, categorical_features: Optional[Union[np.ndarray, list]] = None, features_to_minimize: Optional[Union[np.ndarray, list]] = None, train_only_QI: Optional[bool] = True, is_regression: Optional[bool] = False): - if issubclass(estimator.__class__, Model): - self.estimator = estimator - else: + + self.estimator = estimator + if estimator is not None and not issubclass(estimator.__class__, Model): if is_regression: self.estimator = SklearnRegressor(estimator) else: @@ -832,6 +832,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM self._generalizations = {'ranges': GeneralizeToRepresentative._calculate_ranges(self.cells), 'categories': GeneralizeToRepresentative._calculate_categories(self.cells), 'untouched': GeneralizeToRepresentative._calculate_untouched(self.cells)} + self._remove_categorical_untouched(self._generalizations) def _find_range_count(self, samples, ranges): samples_df = pd.DataFrame(samples, columns=self._categorical_data.columns) @@ -988,3 +989,17 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM del cell['categories'][feature] cell['untouched'].append(feature) cells_by_id[cell['id']] = cell.copy() + + @staticmethod + def _remove_categorical_untouched(generalizations): + to_remove = [] + for feature in generalizations['categories'].keys(): + category_sizes = [len(g) if len(g) > 1 else 0 for g in generalizations['categories'][feature]] + if sum(category_sizes) == 0: + if 'untouched' not in generalizations: + generalizations['untouched'] = [] + generalizations['untouched'].append(feature) + to_remove.append(feature) + + for feature in to_remove: + del generalizations['categories'][feature] diff --git a/tests/test_minimizer.py b/tests/test_minimizer.py index 181755b..6adad90 100644 --- a/tests/test_minimizer.py +++ b/tests/test_minimizer.py @@ -937,3 +937,25 @@ def test_blackbox_model(): if len(expected_generalizations['ranges'].keys()) > 0 or len(expected_generalizations['categories'].keys()) > 0: assert (ncp > 0) assert (((transformed[indexes]) != (X[indexes])).any()) + + +def test_untouched(): + cells = [{"id": 1, "ranges": {"age": {"start": None, "end": 38}}, "label": 0, + 'categories': {'gender': ['male']}, "representative": {"age": 26, "height": 149}}, + {"id": 2, "ranges": {"age": {"start": 39, "end": None}}, "label": 1, + 'categories': {'gender': ['female']}, "representative": {"age": 58, "height": 163}}, + {"id": 3, "ranges": {"age": {"start": None, "end": 38}}, "label": 0, + 'categories': {'gender': ['male']}, "representative": {"age": 31, "height": 184}}, + {"id": 4, "ranges": {"age": {"start": 39, "end": None}}, "label": 1, + 'categories': {'gender': ['male', 'female']}, "representative": {"age": 45, "height": 176}} + ] + gen = GeneralizeToRepresentative(cells=cells) + gen._calculate_generalizations() + gener = gen.generalizations + expected_generalizations = {'ranges': {'age': [38, 39]}, 'categories': {}, 'untouched': ['gender']} + for key in expected_generalizations['ranges']: + assert (set(expected_generalizations['ranges'][key]) == set(gener['ranges'][key])) + for key in expected_generalizations['categories']: + assert (set([frozenset(sl) for sl in expected_generalizations['categories'][key]]) == + set([frozenset(sl) for sl in gener['categories'][key]])) + assert (set(expected_generalizations['untouched']) == set(gener['untouched']))