mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-24 20:36:21 +02:00
Wrap predict method in BlackBoxClassifierPredictMethod to avoid exception in ART when supplied method returns scalars
This commit is contained in:
parent
1cc73b3da1
commit
a7d156660e
2 changed files with 26 additions and 1 deletions
|
|
@ -340,4 +340,11 @@ class BlackboxClassifierPredictFunction(BlackboxClassifier):
|
|||
super().__init__(model, output_type, black_box_access=True, unlimited_queries=unlimited_queries, **kwargs)
|
||||
self._nb_classes = nb_classes
|
||||
self._input_shape = input_shape
|
||||
self._art_model = BlackBoxClassifier(model, self._input_shape, self._nb_classes, preprocessing=None)
|
||||
|
||||
def predict_wrapper(x):
|
||||
predictions = self.model(x)
|
||||
if not is_one_hot(predictions):
|
||||
predictions = check_and_transform_label_format(predictions, nb_classes=nb_classes, return_one_hot=True)
|
||||
return predictions
|
||||
|
||||
self._art_model = BlackBoxClassifier(predict_wrapper, self._input_shape, self._nb_classes, preprocessing=None)
|
||||
|
|
|
|||
|
|
@ -199,6 +199,23 @@ def test_blackbox_classifier_predict():
|
|||
assert (score == 1.0)
|
||||
|
||||
|
||||
def test_blackbox_classifier_predict_scalar():
|
||||
def predict(x):
|
||||
return np.array([[1.0] for i in range(x.shape[0])])
|
||||
|
||||
(x_train, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
|
||||
y_train = np.array([[0, 1, 0] for i in range(105)])
|
||||
|
||||
train = ArrayDataset(x_train, y_train)
|
||||
|
||||
model = BlackboxClassifierPredictFunction(predict, ModelOutputType.CLASSIFIER_SCALAR, (4,), 3)
|
||||
pred = model.predict(train)
|
||||
assert (pred.shape[0] == x_train.shape[0])
|
||||
|
||||
score = model.score(train)
|
||||
assert (score == 1.0)
|
||||
|
||||
|
||||
def test_is_one_hot():
|
||||
(_, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
|
||||
|
||||
|
|
@ -206,6 +223,7 @@ def test_is_one_hot():
|
|||
assert (not is_one_hot(y_train.reshape(-1,1)))
|
||||
assert (is_one_hot(to_categorical(y_train)))
|
||||
|
||||
|
||||
def test_get_nb_classes():
|
||||
(_, y_train), (_, y_test) = dataset_utils.get_iris_dataset_np()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue