mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-02 16:22:37 +02:00
enhance calculation of nb classes + tests (#45)
* update get_nb_classes method to handle 1-hot and scalar input
This commit is contained in:
parent
50317a8d67
commit
e25e58b253
4 changed files with 62 additions and 18 deletions
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
from sklearn.preprocessing import OneHotEncoder
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
from apt.utils.models import Model, ModelOutputType
|
||||
from apt.utils.models import Model, ModelOutputType, get_nb_classes
|
||||
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
|
||||
|
||||
from art.estimators.classification.scikitlearn import SklearnClassifier as ArtSklearnClassifier
|
||||
|
|
@ -59,7 +59,7 @@ class SklearnClassifier(SklearnModel):
|
|||
:return: None
|
||||
"""
|
||||
y = train_data.get_labels()
|
||||
self.nb_classes = self.get_nb_classes(y)
|
||||
self.nb_classes = get_nb_classes(y)
|
||||
y_encoded = check_and_transform_label_format(y, nb_classes=self.nb_classes)
|
||||
self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue