mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-24 20:36:21 +02:00
fix for keras model check_and_transform_label_format requires nb_classes
This commit is contained in:
parent
c6eb553a9f
commit
50317a8d67
1 changed files with 1 additions and 1 deletions
|
|
@ -57,7 +57,7 @@ class KerasClassifier(KerasModel):
|
|||
:type train_data: `Dataset`
|
||||
:return: None
|
||||
"""
|
||||
y_encoded = check_and_transform_label_format(train_data.get_labels())
|
||||
y_encoded = check_and_transform_label_format(train_data.get_labels(), self._art_model.nb_classes)
|
||||
self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs)
|
||||
|
||||
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue