From 1e3dd399d01f7085d3d239e581512920d760cc45 Mon Sep 17 00:00:00 2001 From: abigailt Date: Thu, 18 Jan 2024 17:50:20 +0200 Subject: [PATCH] Fix calculation of score for pytorch models when a single column of probabilities is used for binary classification Signed-off-by: abigailt --- apt/utils/models/pytorch_model.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 6591779..f234311 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -408,9 +408,16 @@ class PyTorchClassifier(PyTorchModel): :type test_data: `PytorchData` :return: the score as float (between 0 and 1) """ - y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes) + y = test_data.get_labels() predicted = self.predict(test_data) - return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0] + # binary classification, single column of probabilities + if self._art_model.nb_classes == 2 and (len(predicted.shape) == 1 or predicted.shape[1] == 1): + if len(predicted.shape) > 1: + y = check_and_transform_label_format(y, self._art_model.nb_classes, return_one_hot=False) + return np.count_nonzero(y == (predicted > 0.5)) / predicted.shape[0] + else: + y = check_and_transform_label_format(y, self._art_model.nb_classes) + return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0] def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None): """