mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-06 18:42:37 +02:00
score
This commit is contained in:
parent
7539ca0ead
commit
e0385b0d04
1 changed files with 4 additions and 2 deletions
|
|
@ -207,8 +207,10 @@ class PyTorchClassifier(PyTorchModel):
|
|||
def score(self, test_data: Dataset, **kwargs):
|
||||
"""
|
||||
Score the model using test data.
|
||||
|
||||
:param test_data: Test data.
|
||||
:type train_data: `Dataset`
|
||||
:return: the score as float (between 0 and 1)
|
||||
"""
|
||||
pass
|
||||
y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes)
|
||||
predicted = self.predict(test_data)
|
||||
return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue