diff --git a/apt/utils/models/model.py b/apt/utils/models/model.py index 902a22c..e81bc89 100644 --- a/apt/utils/models/model.py +++ b/apt/utils/models/model.py @@ -340,4 +340,11 @@ class BlackboxClassifierPredictFunction(BlackboxClassifier): super().__init__(model, output_type, black_box_access=True, unlimited_queries=unlimited_queries, **kwargs) self._nb_classes = nb_classes self._input_shape = input_shape - self._art_model = BlackBoxClassifier(model, self._input_shape, self._nb_classes, preprocessing=None) + + def predict_wrapper(x): + predictions = self.model(x) + if not is_one_hot(predictions): + predictions = check_and_transform_label_format(predictions, nb_classes=nb_classes, return_one_hot=True) + return predictions + + self._art_model = BlackBoxClassifier(predict_wrapper, self._input_shape, self._nb_classes, preprocessing=None) diff --git a/tests/test_model.py b/tests/test_model.py index 8b4769c..7ad260d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -199,6 +199,23 @@ def test_blackbox_classifier_predict(): assert (score == 1.0) +def test_blackbox_classifier_predict_scalar(): + def predict(x): + return np.array([[1.0] for i in range(x.shape[0])]) + + (x_train, y_train), (_, _) = dataset_utils.get_iris_dataset_np() + y_train = np.array([[0, 1, 0] for i in range(105)]) + + train = ArrayDataset(x_train, y_train) + + model = BlackboxClassifierPredictFunction(predict, ModelOutputType.CLASSIFIER_SCALAR, (4,), 3) + pred = model.predict(train) + assert (pred.shape[0] == x_train.shape[0]) + + score = model.score(train) + assert (score == 1.0) + + def test_is_one_hot(): (_, y_train), (_, _) = dataset_utils.get_iris_dataset_np() @@ -206,6 +223,7 @@ def test_is_one_hot(): assert (not is_one_hot(y_train.reshape(-1,1))) assert (is_one_hot(to_categorical(y_train))) + def test_get_nb_classes(): (_, y_train), (_, y_test) = dataset_utils.get_iris_dataset_np()