Remove reshaping

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2023-10-05 13:57:04 +03:00
parent fe9eae45fc
commit ef406cea62

View file

@ -368,7 +368,7 @@ class PyTorchClassifier(PyTorchModel):
if validation_data is None:
self._art_model.fit(
x=train_data.get_samples(),
y=train_data.get_labels().reshape(-1, 1),
y=train_data.get_labels(),
batch_size=batch_size,
nb_epochs=nb_epochs,
save_checkpoints=save_checkpoints,
@ -379,9 +379,9 @@ class PyTorchClassifier(PyTorchModel):
else:
self._art_model.fit(
x=train_data.get_samples(),
y=train_data.get_labels().reshape(-1, 1),
y=train_data.get_labels(),
x_validation=validation_data.get_samples(),
y_validation=validation_data.get_labels().reshape(-1, 1),
y_validation=validation_data.get_labels(),
batch_size=batch_size,
nb_epochs=nb_epochs,
save_checkpoints=save_checkpoints,