Wrap predict method in BlackBoxClassifierPredictMethod to avoid exception in ART when supplied method returns scalars

This commit is contained in:
abigailt 2022-07-20 11:39:13 +03:00 committed by abigailgold
parent 1cc73b3da1
commit a7d156660e
2 changed files with 26 additions and 1 deletions

View file

@ -199,6 +199,23 @@ def test_blackbox_classifier_predict():
assert (score == 1.0)
def test_blackbox_classifier_predict_scalar():
def predict(x):
return np.array([[1.0] for i in range(x.shape[0])])
(x_train, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
y_train = np.array([[0, 1, 0] for i in range(105)])
train = ArrayDataset(x_train, y_train)
model = BlackboxClassifierPredictFunction(predict, ModelOutputType.CLASSIFIER_SCALAR, (4,), 3)
pred = model.predict(train)
assert (pred.shape[0] == x_train.shape[0])
score = model.score(train)
assert (score == 1.0)
def test_is_one_hot():
(_, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
@ -206,6 +223,7 @@ def test_is_one_hot():
assert (not is_one_hot(y_train.reshape(-1,1)))
assert (is_one_hot(to_categorical(y_train)))
def test_get_nb_classes():
(_, y_train), (_, y_test) = dataset_utils.get_iris_dataset_np()