Wrap predict method in BlackBoxClassifierPredictMethod to avoid exception in ART when supplied method returns scalars

This commit is contained in:
abigailt 2022-07-20 11:39:13 +03:00 committed by abigailgold
parent 1cc73b3da1
commit a7d156660e
2 changed files with 26 additions and 1 deletions

View file

@ -340,4 +340,11 @@ class BlackboxClassifierPredictFunction(BlackboxClassifier):
super().__init__(model, output_type, black_box_access=True, unlimited_queries=unlimited_queries, **kwargs)
self._nb_classes = nb_classes
self._input_shape = input_shape
self._art_model = BlackBoxClassifier(model, self._input_shape, self._nb_classes, preprocessing=None)
def predict_wrapper(x):
predictions = self.model(x)
if not is_one_hot(predictions):
predictions = check_and_transform_label_format(predictions, nb_classes=nb_classes, return_one_hot=True)
return predictions
self._art_model = BlackBoxClassifier(predict_wrapper, self._input_shape, self._nb_classes, preprocessing=None)