diff --git a/apt/utils/models/keras_model.py b/apt/utils/models/keras_model.py index 12d8ba3..0cb7252 100644 --- a/apt/utils/models/keras_model.py +++ b/apt/utils/models/keras_model.py @@ -57,7 +57,7 @@ class KerasClassifier(KerasModel): :type train_data: `Dataset` :return: None """ - y_encoded = check_and_transform_label_format(train_data.get_labels()) + y_encoded = check_and_transform_label_format(train_data.get_labels(), self._art_model.nb_classes) self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs) def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE: