Wrap predict method in BlackBoxClassifierPredictMethod to avoid exception in ART when supplied method returns scalars

This commit is contained in:
abigailt 2022-07-20 11:39:13 +03:00 committed by abigailgold
parent 1cc73b3da1
commit a7d156660e
2 changed files with 26 additions and 1 deletions

View file

@ -340,4 +340,11 @@ class BlackboxClassifierPredictFunction(BlackboxClassifier):
super().__init__(model, output_type, black_box_access=True, unlimited_queries=unlimited_queries, **kwargs)
self._nb_classes = nb_classes
self._input_shape = input_shape
self._art_model = BlackBoxClassifier(model, self._input_shape, self._nb_classes, preprocessing=None)
def predict_wrapper(x):
predictions = self.model(x)
if not is_one_hot(predictions):
predictions = check_and_transform_label_format(predictions, nb_classes=nb_classes, return_one_hot=True)
return predictions
self._art_model = BlackBoxClassifier(predict_wrapper, self._input_shape, self._nb_classes, preprocessing=None)

View file

@ -199,6 +199,23 @@ def test_blackbox_classifier_predict():
assert (score == 1.0)
def test_blackbox_classifier_predict_scalar():
def predict(x):
return np.array([[1.0] for i in range(x.shape[0])])
(x_train, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
y_train = np.array([[0, 1, 0] for i in range(105)])
train = ArrayDataset(x_train, y_train)
model = BlackboxClassifierPredictFunction(predict, ModelOutputType.CLASSIFIER_SCALAR, (4,), 3)
pred = model.predict(train)
assert (pred.shape[0] == x_train.shape[0])
score = model.score(train)
assert (score == 1.0)
def test_is_one_hot():
(_, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
@ -206,6 +223,7 @@ def test_is_one_hot():
assert (not is_one_hot(y_train.reshape(-1,1)))
assert (is_one_hot(to_categorical(y_train)))
def test_get_nb_classes():
(_, y_train), (_, y_test) = dataset_utils.get_iris_dataset_np()