ModelOutputType is now a Flag instead of regular enum. Combinations of the base flags are provided for all of the previous output types for convenience. All checks in the code now use the basic flags and not the complex types.

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-07-03 13:29:37 +03:00
parent 2895b40f05
commit 367cae679b
10 changed files with 126 additions and 100 deletions

View file

@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, TensorDataset
from art.utils import check_and_transform_label_format
from apt.utils.datasets.datasets import PytorchData, DatasetWithPredictions, ArrayDataset
from apt.utils.models import Model, ModelOutputType, is_multi_label, is_multi_label_binary
from apt.utils.models import Model, ModelOutputType, is_multi_label, is_multi_label_binary, is_binary
from apt.utils.datasets import OUTPUT_DATA_ARRAY_TYPE, array2numpy
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
@ -56,8 +56,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
super().__init__(model, loss, input_shape, nb_classes, optimizer, use_amp, opt_level, loss_scale,
channels_first, clip_values, preprocessing_defences, postprocessing_defences, preprocessing,
device_type)
self._is_single_binary = (output_type == ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_BINARY_PROBABILITIES
or output_type == ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_BINARY_LOGITS)
self._is_single_binary = not is_multi_label(output_type) and is_binary(output_type)
self._is_multi_label = is_multi_label(output_type)
self._is_multi_label_binary = is_multi_label_binary(output_type)