enhance calculation of nb classes + tests (#45)

* update get_nb_classes method to handle 1-hot and scalar input
This commit is contained in:
Shlomit Shachor 2022-07-05 11:32:17 +03:00 committed by GitHub
parent 50317a8d67
commit e25e58b253
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 18 deletions

View file

@ -8,6 +8,29 @@ from art.estimators.classification import BlackBoxClassifier
from art.utils import check_and_transform_label_format
def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
return len(y.shape) == 2 and y.shape[1] > 1
def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
"""
Get the number of classes from an array of labels
:param y: the labels
:type y: numpy array
:return: the number of classes as integer
"""
if y is None:
return 0
if type(y) != np.ndarray:
raise ValueError("Input should be numpy array")
if is_one_hot(y):
return y.shape[1]
else:
return int(np.max(y) + 1)
class ModelOutputType(Enum):
CLASSIFIER_PROBABILITIES = auto() # vector of probabilities
CLASSIFIER_LOGITS = auto() # vector of logits
@ -117,18 +140,7 @@ class Model(metaclass=ABCMeta):
"""
return self._unlimited_queries
def get_nb_classes(self, y: OUTPUT_DATA_ARRAY_TYPE) -> int:
"""
Get the number of classes from an array of labels
:param y: the labels
:type y: numpy array
:return: the number of classes as integer
"""
if len(y.shape) == 1:
return np.max(y) + 1
else:
return y.shape[1]
class BlackboxClassifier(Model):
@ -233,11 +245,11 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
y_test_pred = model.get_test_labels()
if y_train_pred is not None and len(y_train_pred.shape) == 1:
self._nb_classes = self.get_nb_classes(y_train_pred)
self._nb_classes = get_nb_classes(y_train_pred)
y_train_pred = check_and_transform_label_format(y_train_pred, nb_classes=self._nb_classes)
if y_test_pred is not None and len(y_test_pred.shape) == 1:
if self._nb_classes is None:
self._nb_classes = self.get_nb_classes(y_test_pred)
self._nb_classes = get_nb_classes(y_test_pred)
y_test_pred = check_and_transform_label_format(y_test_pred, nb_classes=self._nb_classes)
if x_train_pred is not None and y_train_pred is not None and x_test_pred is not None and y_test_pred is not None:
@ -255,7 +267,7 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
else:
raise NotImplementedError("Invalid data - None")
self._nb_classes = self.get_nb_classes(y_pred)
self._nb_classes = get_nb_classes(y_pred)
self._input_shape = x_pred.shape[1:]
predict_fn = (x_pred, y_pred)
self._art_model = BlackBoxClassifier(predict_fn, self._input_shape, self._nb_classes, fuzzy_float_compare=True)