This commit is contained in:
olasaadi 2022-05-19 04:40:02 +03:00
parent 7539ca0ead
commit e0385b0d04

View file

@ -207,8 +207,10 @@ class PyTorchClassifier(PyTorchModel):
def score(self, test_data: Dataset, **kwargs):
"""
Score the model using test data.
:param test_data: Test data.
:type train_data: `Dataset`
:return: the score as float (between 0 and 1)
"""
pass
y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes)
predicted = self.predict(test_data)
return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0]