generalize_using_transform=False supported

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2023-05-18 10:32:54 +03:00
parent 51340fa554
commit 4a252e65fe
2 changed files with 32 additions and 4 deletions

View file

@ -75,7 +75,7 @@ def test_minimizer_params_not_transform(data):
model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES)
model.fit(ArrayDataset(X, y))
gen = GeneralizeToRepresentative(model, cells=cells)
gen = GeneralizeToRepresentative(model, cells=cells, generalize_using_transform=False)
gen.calculate_ncp(X)
ncp = gen.ncp
assert (ncp > 0.0)
@ -158,7 +158,7 @@ def test_minimizer_fit_not_transform(data):
if predictions.shape[1] > 1:
predictions = np.argmax(predictions, axis=1)
target_accuracy = 0.5
gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy)
gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy, generalize_using_transform=False)
train_dataset = ArrayDataset(X, predictions, features_names=features)
gen.fit(dataset=train_dataset)
@ -1043,3 +1043,32 @@ def test_untouched():
assert (set([frozenset(sl) for sl in expected_generalizations['categories'][key]])
== set([frozenset(sl) for sl in gener['categories'][key]]))
assert (set(expected_generalizations['untouched']) == set(gener['untouched']))
def test_errors():
features = ['age', 'height']
X = np.array([[23, 165],
[45, 158],
[56, 123],
[67, 154],
[45, 149],
[42, 166],
[73, 172],
[94, 168],
[69, 175],
[24, 181],
[18, 190]])
y = np.array([1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0])
base_est = DecisionTreeClassifier(random_state=0, min_samples_split=2,
min_samples_leaf=1)
model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES)
model.fit(ArrayDataset(X, y))
ad = ArrayDataset(X)
predictions = model.predict(ad)
if predictions.shape[1] > 1:
predictions = np.argmax(predictions, axis=1)
gen = GeneralizeToRepresentative(model, generalize_using_transform=False)
train_dataset = ArrayDataset(X, predictions, features_names=features)
gen.fit(dataset=train_dataset)
with pytest.raises(ValueError):
gen.transform(X)