enhance calculation of nb classes + tests (#45)

* update get_nb_classes method to handle 1-hot and scalar input
This commit is contained in:
Shlomit Shachor 2022-07-05 11:32:17 +03:00 committed by GitHub
parent 50317a8d67
commit e25e58b253
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 18 deletions

View file

@ -2,7 +2,7 @@ import pytest
import numpy as np
from apt.utils.models import SklearnClassifier, SklearnRegressor, ModelOutputType, KerasClassifier, \
BlackboxClassifierPredictions, BlackboxClassifierPredictFunction
BlackboxClassifierPredictions, BlackboxClassifierPredictFunction, is_one_hot, get_nb_classes
from apt.utils.datasets import ArrayDataset, Data
from apt.utils import dataset_utils
@ -13,6 +13,9 @@ from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
from art.utils import to_categorical
def test_sklearn_classifier():
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
underlying_model = RandomForestClassifier()
@ -181,4 +184,33 @@ def test_blackbox_classifier_predict():
score = model.score(train)
assert (0.0 <= score <= 1.0)
def test_is_one_hot():
(_, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
assert (not is_one_hot(y_train))
assert (not is_one_hot(y_train.reshape(-1,1)))
assert (is_one_hot(to_categorical(y_train)))
def test_get_nb_classes():
(_, y_train), (_, y_test) = dataset_utils.get_iris_dataset_np()
# shape: (x,) - not 1-hot
nb_classes_test = get_nb_classes(y_test)
nb_classes_train = get_nb_classes(y_train)
assert (nb_classes_test == nb_classes_train)
assert (nb_classes_test == 3)
# shape: (x,1) - not 1-hot
nb_classes_test = get_nb_classes(y_test.reshape(-1,1))
assert (nb_classes_test == 3)
# shape: (x,3) - 1-hot
y = to_categorical(y_test)
nb_classes = get_nb_classes(y)
assert (nb_classes == 3)
# gaps: 1,2,4 (0,3 missing)
y_test[y_test == 0] = 4
nb_classes = get_nb_classes(y_test)
assert (nb_classes == 5)