One more test + fixes

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2023-11-15 08:54:27 -05:00
parent e7ee42fdc8
commit 904462a6a8
2 changed files with 45 additions and 4 deletions

View file

@ -346,8 +346,9 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# if accuracy above threshold, improve generalization
if accuracy > self.target_accuracy:
print('Improving generalizations')
self._level = 1
self._level = 0
while accuracy > self.target_accuracy:
self._level += 1
cells_previous_iter = self.cells
generalization_prev_iter = self._generalizations
cells_by_id_prev = self._cells_by_id
@ -373,7 +374,6 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
break
else:
print('Pruned tree to level: %d, new relative accuracy: %f' % (self._level, accuracy))
self._level += 1
# if accuracy below threshold, improve accuracy by removing features from generalization
elif accuracy < self.target_accuracy:

View file

@ -181,8 +181,8 @@ def compare_generalizations(gener, expected_generalizations):
== set(gener['range_representatives'][key]))
if 'category_representatives' in expected_generalizations:
for key in expected_generalizations['category_representatives']:
assert (set([frozenset(sl) for sl in expected_generalizations['category_representatives'][key]])
== set([frozenset(sl) for sl in gener['category_representatives'][key]]))
assert (set(expected_generalizations['category_representatives'][key])
== set(gener['category_representatives'][key]))
def check_features(features, expected_generalizations, transformed, x, pandas=False):
@ -961,6 +961,47 @@ def test_minimizer_ndarray_one_hot():
assert ((rel_accuracy >= target_accuracy) or (target_accuracy - rel_accuracy) <= ACCURACY_DIFF)
def test_minimizer_ndarray_one_hot_gen():
x_train = np.array([[23, 0, 1, 165],
[45, 0, 1, 158],
[56, 1, 0, 123],
[67, 0, 1, 154],
[45, 1, 0, 149],
[42, 1, 0, 166],
[73, 0, 1, 172],
[94, 0, 1, 168],
[69, 0, 1, 175],
[24, 1, 0, 181],
[18, 1, 0, 190]])
y_train = np.array([1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0])
model = DecisionTreeClassifier()
model.fit(x_train, y_train)
predictions = model.predict(x_train)
features = ['0', '1', '2', '3']
QI = [0, 1, 2]
QI_slices = [[1, 2]]
target_accuracy = 0.2
gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy, feature_slices=QI_slices,
features_to_minimize=QI)
gen.fit(dataset=ArrayDataset(x_train, predictions))
transformed = gen.transform(dataset=ArrayDataset(x_train))
gener = gen.generalizations
expected_generalizations = {'categories': {'1': [[0, 1]], '2': [[0, 1]]},
'category_representatives': {'1': [0], '2': [1]},
'range_representatives': {'0': []}, 'ranges': {'0': []}, 'untouched': ['3']}
compare_generalizations(gener, expected_generalizations)
check_features(features, expected_generalizations, transformed, x_train)
ncp = gen.ncp.transform_score
check_ncp(ncp, expected_generalizations)
rel_accuracy = model.score(transformed, predictions)
assert ((rel_accuracy >= target_accuracy) or (target_accuracy - rel_accuracy) <= ACCURACY_DIFF)
def test_anonymize_pandas_one_hot():
features = ["age", "gender_M", "gender_F", "height"]
x_train = np.array([[23, 0, 1, 165],