Initial support+test for pytorch multi-label binary classifier

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-02-19 14:16:03 +02:00
parent f197199e54
commit 79534b69db
4 changed files with 193 additions and 14 deletions

View file

@ -1,6 +1,6 @@
from apt.utils.models.model import Model, BlackboxClassifier, ModelOutputType, ScoringMethod, \
BlackboxClassifierPredictions, BlackboxClassifierPredictFunction, get_nb_classes, is_one_hot, \
check_correct_model_output
check_correct_model_output, is_multi_label, is_multi_label_binary
from apt.utils.models.sklearn_model import SklearnModel, SklearnClassifier, SklearnRegressor
from apt.utils.models.keras_model import KerasClassifier, KerasRegressor
from apt.utils.models.xgboost_model import XGBoostClassifier

View file

@ -3,14 +3,14 @@ import os
import shutil
import logging
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, List, Callable
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from art.utils import check_and_transform_label_format
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import Model, ModelOutputType
from apt.utils.models import Model, ModelOutputType, is_multi_label, is_multi_label_binary
from apt.utils.datasets import OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
@ -30,16 +30,45 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
Extension for Pytorch ART model
"""
def __init__(
    self,
    model: "torch.nn.Module",
    loss: "torch.nn.modules.loss._Loss",
    input_shape: Tuple[int, ...],
    nb_classes: int,
    optimizer: Optional["torch.optim.Optimizer"] = None,  # type: ignore
    use_amp: bool = False,
    opt_level: str = "O1",
    loss_scale: Optional[Union[float, str]] = "dynamic",
    channels_first: bool = True,
    clip_values: Optional["CLIP_VALUES_TYPE"] = None,
    preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
    postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
    preprocessing: "PREPROCESSING_TYPE" = (0.0, 1.0),
    device_type: str = "gpu",
):
    """
    Create the wrapper by forwarding every argument, unchanged and in order, to the parent
    constructor, then initialise the label-layout flags.

    The three flags start as `None` (unknown) and are read by `get_step_correct` to decide
    how predictions are compared with targets.
    """
    # All arguments are handed to the parent positionally, in the parent's declared order.
    parent_args = (
        model, loss, input_shape, nb_classes, optimizer, use_amp, opt_level, loss_scale,
        channels_first, clip_values, preprocessing_defences, postprocessing_defences,
        preprocessing, device_type,
    )
    super().__init__(*parent_args)

    # Label-layout flags; `None` means "not determined yet".
    self._is_multi_label_binary = None
    self._is_multi_label = None
    self._is_single_binary = None
def get_step_correct(self, outputs, targets) -> int:
    """
    Get number of correctly classified labels.

    The comparison strategy depends on the label-layout flags stored on the instance:

    * single-output binary (`_is_single_binary`): outputs are rounded and compared to the targets;
    * multi-label (`_is_multi_label`): outputs are compared element-wise to the targets, after rounding
      when the labels are binary (`_is_multi_label_binary`), so every label of every sample counts;
    * otherwise (multi-class): the arg-max over the last dimension is compared to the targets.

    :param outputs: Model outputs for one batch.
    :type outputs: `torch.Tensor`
    :param targets: Ground-truth labels for the same batch.
    :type targets: `torch.Tensor`
    :return: The number of matching entries.
    :raises ValueError: If `outputs` and `targets` do not have the same length.
    """
    # here everything is torch tensors
    if len(outputs) != len(targets):
        raise ValueError("outputs and targets should be the same length.")

    if self._is_single_binary:
        predictions = torch.round(outputs)
        # A (N, 1) score column compared with (N,) labels would broadcast to (N, N) and
        # over-count matches, so align the shapes before comparing.
        if predictions.shape != targets.shape and predictions.numel() == targets.numel():
            predictions = predictions.reshape(targets.shape)
        return int(torch.sum(predictions == targets).item())

    if self._is_multi_label:
        # Only binary multi-label outputs are rounded; other multi-label outputs are compared as-is.
        predictions = torch.round(outputs) if self._is_multi_label_binary else outputs
        return int(torch.sum(targets == predictions).item())

    # Multi-class: pick the highest-scoring class per sample.
    return int(torch.sum(torch.argmax(outputs, dim=-1) == targets).item())
def _eval(self, loader: DataLoader):
"""
@ -93,6 +122,10 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
supported for PyTorch and providing it takes no effect.
"""
self._is_single_binary = self.nb_classes == 2 and (len(y.shape) == 1 or y.shape[1] == 1)
self._is_multi_label = is_multi_label(y)
self._is_multi_label_binary = is_multi_label_binary(y)
# Put the model in the training mode
self._model.train()
@ -400,24 +433,43 @@ class PyTorchClassifier(PyTorchModel):
"""
return self._art_model.predict(x.get_samples(), **kwargs)
def score(self, test_data: PytorchData, binary_threshold: Optional[float] = 0.5,
          apply_non_linearity: Optional[Callable] = None, **kwargs):
    """
    Score the model using test data.

    Three label layouts are handled:

    * binary classification with a single column of probabilities: a prediction is positive when it is
      strictly greater than `binary_threshold`;
    * multi-column, single-label classification: arg-max of the predictions is compared with arg-max of
      the (one-hot) labels;
    * multi-label classification: predictions are compared element-wise with the labels (after
      thresholding with `>= binary_threshold` when the labels are binary) and the score is the fraction
      of matching label entries over all samples and labels.

    :param test_data: Test data.
    :type test_data: `PytorchData`
    :param binary_threshold: The threshold to use on binary classification probabilities to assign the positive
                             class.
    :type binary_threshold: float, optional. Default is 0.5.
    :param apply_non_linearity: A non-linear function to apply to the result of the 'predict' method, in case the
                                model outputs logits (e.g., sigmoid).
    :type apply_non_linearity: Callable, should be possible to apply directly to the numpy output of the 'predict'
                               method, optional.
    :return: the score as float (between 0 and 1)
    """
    # numpy arrays
    y = test_data.get_labels()
    predicted = self.predict(test_data)
    if apply_non_linearity:
        predicted = apply_non_linearity(predicted)

    nb_classes = self._art_model.nb_classes
    nb_samples = predicted.shape[0]

    # binary classification, single column of probabilities
    if nb_classes == 2 and (len(predicted.shape) == 1 or predicted.shape[1] == 1):
        if len(predicted.shape) > 1:
            y = check_and_transform_label_format(y, nb_classes, return_one_hot=False)
        # Comparing (N,) with (N, 1) would broadcast to (N, N) and inflate the count,
        # so bring the labels to the predictions' shape when only the layout differs.
        if np.shape(y) != predicted.shape and np.size(y) == predicted.size:
            y = np.reshape(y, predicted.shape)
        return np.count_nonzero(y == (predicted > binary_threshold)) / nb_samples

    # multi column, one label per sample
    if not is_multi_label(y):
        y = check_and_transform_label_format(y, nb_classes)
        return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / nb_samples

    # multi-label: every (sample, label) entry is scored independently
    if is_multi_label_binary(y):
        # Threshold in a single pass on a new array: the model output is left untouched and
        # entries already set to 0 cannot be re-classified by a second comparison.
        predicted = np.where(predicted >= binary_threshold, 1, 0)
    return np.count_nonzero(y == predicted) / (nb_samples * y.shape[1])
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
"""