Fix misclassification of categorical features with no generalizations (now appear under the 'untouched' category)

This commit is contained in:
abigailt 2022-05-19 16:41:31 +03:00 committed by abigailgold
parent fe676fa426
commit 186f11eaaf
2 changed files with 40 additions and 3 deletions

View file

@ -62,9 +62,9 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
cells: Optional[list] = None, categorical_features: Optional[Union[np.ndarray, list]] = None,
features_to_minimize: Optional[Union[np.ndarray, list]] = None, train_only_QI: Optional[bool] = True,
is_regression: Optional[bool] = False):
if issubclass(estimator.__class__, Model):
self.estimator = estimator
else:
self.estimator = estimator
if estimator is not None and not issubclass(estimator.__class__, Model):
if is_regression:
self.estimator = SklearnRegressor(estimator)
else:
@ -832,6 +832,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self._generalizations = {'ranges': GeneralizeToRepresentative._calculate_ranges(self.cells),
'categories': GeneralizeToRepresentative._calculate_categories(self.cells),
'untouched': GeneralizeToRepresentative._calculate_untouched(self.cells)}
self._remove_categorical_untouched(self._generalizations)
def _find_range_count(self, samples, ranges):
samples_df = pd.DataFrame(samples, columns=self._categorical_data.columns)
@ -988,3 +989,17 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
del cell['categories'][feature]
cell['untouched'].append(feature)
cells_by_id[cell['id']] = cell.copy()
@staticmethod
def _remove_categorical_untouched(generalizations):
to_remove = []
for feature in generalizations['categories'].keys():
category_sizes = [len(g) if len(g) > 1 else 0 for g in generalizations['categories'][feature]]
if sum(category_sizes) == 0:
if 'untouched' not in generalizations:
generalizations['untouched'] = []
generalizations['untouched'].append(feature)
to_remove.append(feature)
for feature in to_remove:
del generalizations['categories'][feature]

View file

@ -937,3 +937,25 @@ def test_blackbox_model():
if len(expected_generalizations['ranges'].keys()) > 0 or len(expected_generalizations['categories'].keys()) > 0:
assert (ncp > 0)
assert (((transformed[indexes]) != (X[indexes])).any())
def test_untouched():
    """
    Check the generalizations computed from four hand-built cells.

    The expected result has one numeric range for 'age', no categorical
    generalizations, and 'gender' listed under 'untouched'.
    """
    input_cells = [
        {"id": 1, "ranges": {"age": {"start": None, "end": 38}}, "label": 0,
         'categories': {'gender': ['male']}, "representative": {"age": 26, "height": 149}},
        {"id": 2, "ranges": {"age": {"start": 39, "end": None}}, "label": 1,
         'categories': {'gender': ['female']}, "representative": {"age": 58, "height": 163}},
        {"id": 3, "ranges": {"age": {"start": None, "end": 38}}, "label": 0,
         'categories': {'gender': ['male']}, "representative": {"age": 31, "height": 184}},
        {"id": 4, "ranges": {"age": {"start": 39, "end": None}}, "label": 1,
         'categories': {'gender': ['male', 'female']}, "representative": {"age": 45, "height": 176}},
    ]
    minimizer = GeneralizeToRepresentative(cells=input_cells)
    minimizer._calculate_generalizations()
    actual = minimizer.generalizations

    expected = {'ranges': {'age': [38, 39]}, 'categories': {}, 'untouched': ['gender']}

    # Range boundaries are compared as sets, so their order is irrelevant.
    for feature, bounds in expected['ranges'].items():
        assert set(bounds) == set(actual['ranges'][feature])

    # Only features named in `expected` are compared; with the empty
    # 'categories' expectation above this loop performs no checks.
    for feature, groups in expected['categories'].items():
        expected_groups = {frozenset(group) for group in groups}
        actual_groups = {frozenset(group) for group in actual['categories'][feature]}
        assert expected_groups == actual_groups

    assert set(expected['untouched']) == set(actual['untouched'])