mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-10 12:32:38 +02:00
enhance calculation of nb classes + tests (#45)
* update get_nb_classes method to handle 1-hot and scalar input
This commit is contained in:
parent
50317a8d67
commit
e25e58b253
4 changed files with 62 additions and 18 deletions
|
|
@ -8,6 +8,29 @@ from art.estimators.classification import BlackBoxClassifier
|
|||
from art.utils import check_and_transform_label_format
|
||||
|
||||
|
||||
def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
|
||||
return len(y.shape) == 2 and y.shape[1] > 1
|
||||
|
||||
|
||||
def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
|
||||
"""
|
||||
Get the number of classes from an array of labels
|
||||
|
||||
:param y: the labels
|
||||
:type y: numpy array
|
||||
:return: the number of classes as integer
|
||||
"""
|
||||
if y is None:
|
||||
return 0
|
||||
|
||||
if type(y) != np.ndarray:
|
||||
raise ValueError("Input should be numpy array")
|
||||
|
||||
if is_one_hot(y):
|
||||
return y.shape[1]
|
||||
else:
|
||||
return int(np.max(y) + 1)
|
||||
|
||||
class ModelOutputType(Enum):
|
||||
CLASSIFIER_PROBABILITIES = auto() # vector of probabilities
|
||||
CLASSIFIER_LOGITS = auto() # vector of logits
|
||||
|
|
@ -117,18 +140,7 @@ class Model(metaclass=ABCMeta):
|
|||
"""
|
||||
return self._unlimited_queries
|
||||
|
||||
def get_nb_classes(self, y: OUTPUT_DATA_ARRAY_TYPE) -> int:
|
||||
"""
|
||||
Get the number of classes from an array of labels
|
||||
|
||||
:param y: the labels
|
||||
:type y: numpy array
|
||||
:return: the number of classes as integer
|
||||
"""
|
||||
if len(y.shape) == 1:
|
||||
return np.max(y) + 1
|
||||
else:
|
||||
return y.shape[1]
|
||||
|
||||
|
||||
class BlackboxClassifier(Model):
|
||||
|
|
@ -233,11 +245,11 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
|
|||
y_test_pred = model.get_test_labels()
|
||||
|
||||
if y_train_pred is not None and len(y_train_pred.shape) == 1:
|
||||
self._nb_classes = self.get_nb_classes(y_train_pred)
|
||||
self._nb_classes = get_nb_classes(y_train_pred)
|
||||
y_train_pred = check_and_transform_label_format(y_train_pred, nb_classes=self._nb_classes)
|
||||
if y_test_pred is not None and len(y_test_pred.shape) == 1:
|
||||
if self._nb_classes is None:
|
||||
self._nb_classes = self.get_nb_classes(y_test_pred)
|
||||
self._nb_classes = get_nb_classes(y_test_pred)
|
||||
y_test_pred = check_and_transform_label_format(y_test_pred, nb_classes=self._nb_classes)
|
||||
|
||||
if x_train_pred is not None and y_train_pred is not None and x_test_pred is not None and y_test_pred is not None:
|
||||
|
|
@ -255,7 +267,7 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
|
|||
else:
|
||||
raise NotImplementedError("Invalid data - None")
|
||||
|
||||
self._nb_classes = self.get_nb_classes(y_pred)
|
||||
self._nb_classes = get_nb_classes(y_pred)
|
||||
self._input_shape = x_pred.shape[1:]
|
||||
predict_fn = (x_pred, y_pred)
|
||||
self._art_model = BlackBoxClassifier(predict_fn, self._input_shape, self._nb_classes, fuzzy_float_compare=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue