mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-07-02 16:01:00 +02:00
ModelOutputType is now a Flag instead of regular enum. Combinations of the base flags are provided for all of the previous output types for convenience. All checks in the code now use the basic flags and not the complex types.
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
2895b40f05
commit
367cae679b
10 changed files with 126 additions and 100 deletions
|
|
@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, TensorDataset
|
|||
|
||||
from art.utils import check_and_transform_label_format
|
||||
from apt.utils.datasets.datasets import PytorchData, DatasetWithPredictions, ArrayDataset
|
||||
from apt.utils.models import Model, ModelOutputType, is_multi_label, is_multi_label_binary
|
||||
from apt.utils.models import Model, ModelOutputType, is_multi_label, is_multi_label_binary, is_binary
|
||||
from apt.utils.datasets import OUTPUT_DATA_ARRAY_TYPE, array2numpy
|
||||
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
|
||||
|
||||
|
|
@ -56,8 +56,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
super().__init__(model, loss, input_shape, nb_classes, optimizer, use_amp, opt_level, loss_scale,
|
||||
channels_first, clip_values, preprocessing_defences, postprocessing_defences, preprocessing,
|
||||
device_type)
|
||||
self._is_single_binary = (output_type == ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_BINARY_PROBABILITIES
|
||||
or output_type == ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_BINARY_LOGITS)
|
||||
self._is_single_binary = not is_multi_label(output_type) and is_binary(output_type)
|
||||
self._is_multi_label = is_multi_label(output_type)
|
||||
self._is_multi_label_binary = is_multi_label_binary(output_type)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue