From 50317a8d67fcd1f1c7b3f17aeea5cfc4dd4d6114 Mon Sep 17 00:00:00 2001 From: natali Date: Sat, 2 Jul 2022 23:29:49 +0300 Subject: [PATCH] fix for keras model check_and_transform_label_format requires nb_classes --- apt/utils/models/keras_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apt/utils/models/keras_model.py b/apt/utils/models/keras_model.py index 12d8ba3..0cb7252 100644 --- a/apt/utils/models/keras_model.py +++ b/apt/utils/models/keras_model.py @@ -57,7 +57,7 @@ class KerasClassifier(KerasModel): :type train_data: `Dataset` :return: None """ - y_encoded = check_and_transform_label_format(train_data.get_labels()) + y_encoded = check_and_transform_label_format(train_data.get_labels(), self._art_model.nb_classes) self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs) def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE: