mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-24 20:36:21 +02:00
enhance calculation of nb classes + tests (#45)
* update get_nb_classes method to handle 1-hot and scalar input
This commit is contained in:
parent
50317a8d67
commit
e25e58b253
4 changed files with 62 additions and 18 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from apt.utils.models.model import Model, BlackboxClassifier, ModelOutputType, ScoringMethod, \
|
||||
BlackboxClassifierPredictions, BlackboxClassifierPredictFunction
|
||||
BlackboxClassifierPredictions, BlackboxClassifierPredictFunction, get_nb_classes, is_one_hot
|
||||
from apt.utils.models.sklearn_model import SklearnModel, SklearnClassifier, SklearnRegressor
|
||||
from apt.utils.models.keras_model import KerasClassifier
|
||||
|
|
|
|||
|
|
@ -8,6 +8,29 @@ from art.estimators.classification import BlackBoxClassifier
|
|||
from art.utils import check_and_transform_label_format
|
||||
|
||||
|
||||
def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
|
||||
return len(y.shape) == 2 and y.shape[1] > 1
|
||||
|
||||
|
||||
def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
|
||||
"""
|
||||
Get the number of classes from an array of labels
|
||||
|
||||
:param y: the labels
|
||||
:type y: numpy array
|
||||
:return: the number of classes as integer
|
||||
"""
|
||||
if y is None:
|
||||
return 0
|
||||
|
||||
if type(y) != np.ndarray:
|
||||
raise ValueError("Input should be numpy array")
|
||||
|
||||
if is_one_hot(y):
|
||||
return y.shape[1]
|
||||
else:
|
||||
return int(np.max(y) + 1)
|
||||
|
||||
class ModelOutputType(Enum):
|
||||
CLASSIFIER_PROBABILITIES = auto() # vector of probabilities
|
||||
CLASSIFIER_LOGITS = auto() # vector of logits
|
||||
|
|
@ -117,18 +140,7 @@ class Model(metaclass=ABCMeta):
|
|||
"""
|
||||
return self._unlimited_queries
|
||||
|
||||
def get_nb_classes(self, y: OUTPUT_DATA_ARRAY_TYPE) -> int:
|
||||
"""
|
||||
Get the number of classes from an array of labels
|
||||
|
||||
:param y: the labels
|
||||
:type y: numpy array
|
||||
:return: the number of classes as integer
|
||||
"""
|
||||
if len(y.shape) == 1:
|
||||
return np.max(y) + 1
|
||||
else:
|
||||
return y.shape[1]
|
||||
|
||||
|
||||
class BlackboxClassifier(Model):
|
||||
|
|
@ -233,11 +245,11 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
|
|||
y_test_pred = model.get_test_labels()
|
||||
|
||||
if y_train_pred is not None and len(y_train_pred.shape) == 1:
|
||||
self._nb_classes = self.get_nb_classes(y_train_pred)
|
||||
self._nb_classes = get_nb_classes(y_train_pred)
|
||||
y_train_pred = check_and_transform_label_format(y_train_pred, nb_classes=self._nb_classes)
|
||||
if y_test_pred is not None and len(y_test_pred.shape) == 1:
|
||||
if self._nb_classes is None:
|
||||
self._nb_classes = self.get_nb_classes(y_test_pred)
|
||||
self._nb_classes = get_nb_classes(y_test_pred)
|
||||
y_test_pred = check_and_transform_label_format(y_test_pred, nb_classes=self._nb_classes)
|
||||
|
||||
if x_train_pred is not None and y_train_pred is not None and x_test_pred is not None and y_test_pred is not None:
|
||||
|
|
@ -255,7 +267,7 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
|
|||
else:
|
||||
raise NotImplementedError("Invalid data - None")
|
||||
|
||||
self._nb_classes = self.get_nb_classes(y_pred)
|
||||
self._nb_classes = get_nb_classes(y_pred)
|
||||
self._input_shape = x_pred.shape[1:]
|
||||
predict_fn = (x_pred, y_pred)
|
||||
self._art_model = BlackBoxClassifier(predict_fn, self._input_shape, self._nb_classes, fuzzy_float_compare=True)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
from sklearn.preprocessing import OneHotEncoder
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
from apt.utils.models import Model, ModelOutputType
|
||||
from apt.utils.models import Model, ModelOutputType, get_nb_classes
|
||||
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
|
||||
|
||||
from art.estimators.classification.scikitlearn import SklearnClassifier as ArtSklearnClassifier
|
||||
|
|
@ -59,7 +59,7 @@ class SklearnClassifier(SklearnModel):
|
|||
:return: None
|
||||
"""
|
||||
y = train_data.get_labels()
|
||||
self.nb_classes = self.get_nb_classes(y)
|
||||
self.nb_classes = get_nb_classes(y)
|
||||
y_encoded = check_and_transform_label_format(y, nb_classes=self.nb_classes)
|
||||
self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
import numpy as np
|
||||
|
||||
from apt.utils.models import SklearnClassifier, SklearnRegressor, ModelOutputType, KerasClassifier, \
|
||||
BlackboxClassifierPredictions, BlackboxClassifierPredictFunction
|
||||
BlackboxClassifierPredictions, BlackboxClassifierPredictFunction, is_one_hot, get_nb_classes
|
||||
from apt.utils.datasets import ArrayDataset, Data
|
||||
from apt.utils import dataset_utils
|
||||
|
||||
|
|
@ -13,6 +13,9 @@ from tensorflow.keras.models import Sequential
|
|||
from tensorflow.keras.layers import Dense, Input
|
||||
|
||||
|
||||
from art.utils import to_categorical
|
||||
|
||||
|
||||
def test_sklearn_classifier():
|
||||
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
|
||||
underlying_model = RandomForestClassifier()
|
||||
|
|
@ -181,4 +184,33 @@ def test_blackbox_classifier_predict():
|
|||
score = model.score(train)
|
||||
assert (0.0 <= score <= 1.0)
|
||||
|
||||
def test_is_one_hot():
|
||||
(_, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
|
||||
|
||||
assert (not is_one_hot(y_train))
|
||||
assert (not is_one_hot(y_train.reshape(-1,1)))
|
||||
assert (is_one_hot(to_categorical(y_train)))
|
||||
|
||||
def test_get_nb_classes():
|
||||
(_, y_train), (_, y_test) = dataset_utils.get_iris_dataset_np()
|
||||
|
||||
# shape: (x,) - not 1-hot
|
||||
nb_classes_test = get_nb_classes(y_test)
|
||||
nb_classes_train = get_nb_classes(y_train)
|
||||
assert (nb_classes_test == nb_classes_train)
|
||||
assert (nb_classes_test == 3)
|
||||
|
||||
# shape: (x,1) - not 1-hot
|
||||
nb_classes_test = get_nb_classes(y_test.reshape(-1,1))
|
||||
assert (nb_classes_test == 3)
|
||||
|
||||
# shape: (x,3) - 1-hot
|
||||
y = to_categorical(y_test)
|
||||
nb_classes = get_nb_classes(y)
|
||||
assert (nb_classes == 3)
|
||||
|
||||
# gaps: 1,2,4 (0,3 missing)
|
||||
y_test[y_test == 0] = 4
|
||||
nb_classes = get_nb_classes(y_test)
|
||||
assert (nb_classes == 5)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue