diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index ec03f70..19869e8 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -207,8 +207,10 @@ class PyTorchClassifier(PyTorchModel): def score(self, test_data: Dataset, **kwargs): """ Score the model using test data. - :param test_data: Test data. :type train_data: `Dataset` + :return: the score as float (between 0 and 1) """ - pass + y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes) + predicted = self.predict(test_data) + return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0]