fix for keras model check_and_transform_label_format requires nb_classes

This commit is contained in:
natali 2022-07-02 23:29:49 +03:00 committed by abigailgold
parent c6eb553a9f
commit 50317a8d67

View file

@ -57,7 +57,7 @@ class KerasClassifier(KerasModel):
:type train_data: `Dataset`
:return: None
"""
y_encoded = check_and_transform_label_format(train_data.get_labels())
y_encoded = check_and_transform_label_format(train_data.get_labels(), self._art_model.nb_classes)
self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs)
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE: