Replace values in multi-column 1-hot encoded features instead of appending so that options are narrowed down

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2023-12-24 13:13:52 -05:00
parent 0e01e19e0c
commit f646109e84
2 changed files with 53 additions and 11 deletions

View file

@@ -1060,6 +1060,52 @@ def test_minimizer_ndarray_one_hot_multi():
assert ((np.min(transformed_slice, axis=1) == 0).all())
def test_minimizer_ndarray_one_hot_multi2():
    """Minimize an ndarray that consists of a single 3-column 1-hot encoded feature.

    Every training row has exactly one active column, and the label is
    determined by which column is active. All three columns are declared as
    one feature slice, so after the transform every row of that slice must
    still be a valid 1-hot vector (exactly one 1, all other entries 0).
    """
    x_train = np.array([[0, 0, 1],
                        [0, 0, 1],
                        [0, 1, 0],
                        [0, 1, 0],
                        [1, 0, 0],
                        [1, 0, 0]])
    y_train = np.array([1, 1, 2, 2, 0, 0])

    model = DecisionTreeClassifier()
    model.fit(x_train, y_train)
    # The minimizer is fitted against the model's own predictions, not y_train.
    predictions = model.predict(x_train)

    features = ['0', '1', '2']
    QI = [0, 1, 2]
    QI_slices = [[0, 1, 2]]
    target_accuracy = 0.2
    gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy, feature_slices=QI_slices,
                                     features_to_minimize=QI)
    gen.fit(dataset=ArrayDataset(x_train, predictions))
    transformed = gen.transform(dataset=ArrayDataset(x_train))
    gener = gen.generalizations
    # NOTE(review): these expected generalizations name features '1'-'6' and a
    # range feature '0', while this dataset only declares features '0'-'2' --
    # confirm they are the intended expectation for this 3-column input.
    expected_generalizations = {'categories':
                                {'1': [[0, 1]], '2': [[0, 1]], '3': [[0, 1]], '4': [[0, 1]], '5': [[0, 1]]},
                                'category_representatives': {'1': [0], '2': [1], '3': [0], '4': [1], '5': [0]},
                                'range_representatives': {'0': []}, 'ranges': {'0': []}, 'untouched': ['6']}

    compare_generalizations(gener, expected_generalizations)

    check_features(features, expected_generalizations, transformed, x_train)
    ncp = gen.ncp.transform_score
    check_ncp(ncp, expected_generalizations)

    rel_accuracy = model.score(transformed, predictions)
    assert ((rel_accuracy >= target_accuracy) or (target_accuracy - rel_accuracy) <= ACCURACY_DIFF)

    # Validate every declared slice. Only one slice exists here, so the
    # previous hard-coded QI_slices[1] lookup raised IndexError.
    for qi_slice in QI_slices:
        transformed_slice = transformed[:, qi_slice]
        assert ((np.sum(transformed_slice, axis=1) == 1).all())
        assert ((np.max(transformed_slice, axis=1) == 1).all())
        assert ((np.min(transformed_slice, axis=1) == 0).all())
def test_anonymize_pandas_one_hot():
features = ["age", "gender_M", "gender_F", "height"]
x_train = np.array([[23, 0, 1, 165],