mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-29 23:06:21 +02:00
Wrap predict method in BlackBoxClassifierPredictMethod to avoid exception in ART when supplied method returns scalars
This commit is contained in:
parent
1cc73b3da1
commit
a7d156660e
2 changed files with 26 additions and 1 deletions
|
|
@ -340,4 +340,11 @@ class BlackboxClassifierPredictFunction(BlackboxClassifier):
|
|||
super().__init__(model, output_type, black_box_access=True, unlimited_queries=unlimited_queries, **kwargs)
|
||||
self._nb_classes = nb_classes
|
||||
self._input_shape = input_shape
|
||||
self._art_model = BlackBoxClassifier(model, self._input_shape, self._nb_classes, preprocessing=None)
|
||||
|
||||
def predict_wrapper(x):
|
||||
predictions = self.model(x)
|
||||
if not is_one_hot(predictions):
|
||||
predictions = check_and_transform_label_format(predictions, nb_classes=nb_classes, return_one_hot=True)
|
||||
return predictions
|
||||
|
||||
self._art_model = BlackBoxClassifier(predict_wrapper, self._input_shape, self._nb_classes, preprocessing=None)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue