mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-08 15:05:13 +02:00
Remove reshaping
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
fe9eae45fc
commit
ef406cea62
1 changed files with 3 additions and 3 deletions
|
|
@ -368,7 +368,7 @@ class PyTorchClassifier(PyTorchModel):
|
|||
if validation_data is None:
|
||||
self._art_model.fit(
|
||||
x=train_data.get_samples(),
|
||||
y=train_data.get_labels().reshape(-1, 1),
|
||||
y=train_data.get_labels(),
|
||||
batch_size=batch_size,
|
||||
nb_epochs=nb_epochs,
|
||||
save_checkpoints=save_checkpoints,
|
||||
|
|
@ -379,9 +379,9 @@ class PyTorchClassifier(PyTorchModel):
|
|||
else:
|
||||
self._art_model.fit(
|
||||
x=train_data.get_samples(),
|
||||
y=train_data.get_labels().reshape(-1, 1),
|
||||
y=train_data.get_labels(),
|
||||
x_validation=validation_data.get_samples(),
|
||||
y_validation=validation_data.get_labels().reshape(-1, 1),
|
||||
y_validation=validation_data.get_labels(),
|
||||
batch_size=batch_size,
|
||||
nb_epochs=nb_epochs,
|
||||
save_checkpoints=save_checkpoints,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue