diff --git a/apt/utils/models/model.py b/apt/utils/models/model.py index d9811e1..06cae10 100644 --- a/apt/utils/models/model.py +++ b/apt/utils/models/model.py @@ -316,10 +316,20 @@ class BlackboxClassifierPredictions(BlackboxClassifier): self._nb_classes = get_nb_classes(y_pred) self._input_shape = x_pred.shape[1:] + self._x_pred = x_pred + self._y_pred = y_pred predict_fn = (x_pred, y_pred) self._art_model = BlackBoxClassifier(predict_fn, self._input_shape, self._nb_classes, fuzzy_float_compare=True, preprocessing=None) + def get_predictions(self) -> Tuple[OUTPUT_DATA_ARRAY_TYPE, OUTPUT_DATA_ARRAY_TYPE]: + """ + Return all the data for which the model contains predictions. + + :return: Tuple containing data and predictions as numpy arrays. + """ + return self._x_pred, self._y_pred + class BlackboxClassifierPredictFunction(BlackboxClassifier): """ diff --git a/tests/test_model.py b/tests/test_model.py index 821c776..21c8fff 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -12,6 +12,8 @@ from sklearn.ensemble import RandomForestClassifier from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Input +from art.utils import check_and_transform_label_format + from art.utils import to_categorical @@ -158,6 +160,10 @@ def test_blackbox_classifier_no_test(): score = model.score(train) assert (score == 1.0) + predictions_x, predictions_y = model.get_predictions() + assert np.array_equal(predictions_x, x_train) + assert np.array_equal(predictions_y, check_and_transform_label_format(y_train, nb_classes=3)) + def test_blackbox_classifier_no_train(): (_, _), (x_test, y_test) = dataset_utils.get_iris_dataset_np() @@ -171,6 +177,10 @@ def test_blackbox_classifier_no_train(): score = model.score(test) assert (score == 1.0) + predictions_x, predictions_y = model.get_predictions() + assert np.array_equal(predictions_x, x_test) + assert np.array_equal(predictions_y, check_and_transform_label_format(y_test, nb_classes=3)) + def test_blackbox_classifier_no_test_y(): (x_train, y_train), (x_test, _) = dataset_utils.get_iris_dataset_np()