From e0385b0d047dc46676891162a901f6b8f9e63ca3 Mon Sep 17 00:00:00 2001 From: olasaadi Date: Thu, 19 May 2022 04:40:02 +0300 Subject: [PATCH] score --- apt/utils/models/pytorch_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index ec03f70..19869e8 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -207,8 +207,10 @@ class PyTorchClassifier(PyTorchModel): def score(self, test_data: Dataset, **kwargs): """ Score the model using test data. - :param test_data: Test data. :type train_data: `Dataset` + :return: the score as float (between 0 and 1) """ - pass + y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes) + predicted = self.predict(test_data) + return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0]