diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 6591779..f234311 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -408,9 +408,16 @@ class PyTorchClassifier(PyTorchModel): :type test_data: `PytorchData` :return: the score as float (between 0 and 1) """ - y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes) + y = test_data.get_labels() predicted = self.predict(test_data) - return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0] + # binary classification, single column of probabilities + if self._art_model.nb_classes == 2 and (len(predicted.shape) == 1 or predicted.shape[1] == 1): + if len(predicted.shape) > 1: + y = check_and_transform_label_format(y, self._art_model.nb_classes, return_one_hot=False) + return np.count_nonzero(y == (predicted > 0.5)) / predicted.shape[0] + else: + y = check_and_transform_label_format(y, self._art_model.nb_classes) + return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0] def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None): """