enhance calculation of nb classes + tests (#45)

* update get_nb_classes method to handle 1-hot and scalar input
This commit is contained in:
Shlomit Shachor 2022-07-05 11:32:17 +03:00 committed by GitHub
parent 50317a8d67
commit e25e58b253
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 18 deletions

View file

@ -3,7 +3,7 @@ from typing import Optional
from sklearn.preprocessing import OneHotEncoder
from sklearn.base import BaseEstimator
from apt.utils.models import Model, ModelOutputType
from apt.utils.models import Model, ModelOutputType, get_nb_classes
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.scikitlearn import SklearnClassifier as ArtSklearnClassifier
@ -59,7 +59,7 @@ class SklearnClassifier(SklearnModel):
:return: None
"""
y = train_data.get_labels()
self.nb_classes = self.get_nb_classes(y)
self.nb_classes = get_nb_classes(y)
y_encoded = check_and_transform_label_format(y, nb_classes=self.nb_classes)
self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs)