From ef406cea62096625cc2b4ba0954546cbdbc6ae38 Mon Sep 17 00:00:00 2001 From: abigailt Date: Thu, 5 Oct 2023 13:57:04 +0300 Subject: [PATCH] Remove reshaping Signed-off-by: abigailt --- apt/utils/models/pytorch_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index a97fd33..6591779 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -368,7 +368,7 @@ class PyTorchClassifier(PyTorchModel): if validation_data is None: self._art_model.fit( x=train_data.get_samples(), - y=train_data.get_labels().reshape(-1, 1), + y=train_data.get_labels(), batch_size=batch_size, nb_epochs=nb_epochs, save_checkpoints=save_checkpoints, @@ -379,9 +379,9 @@ class PyTorchClassifier(PyTorchModel): else: self._art_model.fit( x=train_data.get_samples(), - y=train_data.get_labels().reshape(-1, 1), + y=train_data.get_labels(), x_validation=validation_data.get_samples(), - y_validation=validation_data.get_labels().reshape(-1, 1), + y_validation=validation_data.get_labels(), batch_size=batch_size, nb_epochs=nb_epochs, save_checkpoints=save_checkpoints,