mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-24 20:36:21 +02:00
Fix calculation of score for pytorch models when a single column of probabilities is used for binary classification (#91)
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
a8f5326572
commit
cb70ca10e6
1 changed files with 9 additions and 2 deletions
|
|
@ -408,9 +408,16 @@ class PyTorchClassifier(PyTorchModel):
|
|||
:type test_data: `PytorchData`
|
||||
:return: the score as float (between 0 and 1)
|
||||
"""
|
||||
y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes)
|
||||
y = test_data.get_labels()
|
||||
predicted = self.predict(test_data)
|
||||
return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0]
|
||||
# binary classification, single column of probabilities
|
||||
if self._art_model.nb_classes == 2 and (len(predicted.shape) == 1 or predicted.shape[1] == 1):
|
||||
if len(predicted.shape) > 1:
|
||||
y = check_and_transform_label_format(y, self._art_model.nb_classes, return_one_hot=False)
|
||||
return np.count_nonzero(y == (predicted > 0.5)) / predicted.shape[0]
|
||||
else:
|
||||
y = check_and_transform_label_format(y, self._art_model.nb_classes)
|
||||
return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0]
|
||||
|
||||
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue