diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 7470e31..f3eb86c 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -94,9 +94,9 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): if self._optimizer is None: # pragma: no cover raise ValueError("An optimizer is needed to train the model, but none for provided.") - _y = check_and_transform_label_format(y, self.nb_classes) + y = check_and_transform_label_format(y, self.nb_classes) # Apply preprocessing - x_preprocessed, y_preprocessed = self._apply_preprocessing(x, _y, fit=True) + x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True) # Check label shape y_preprocessed = self.reduce_labels(y_preprocessed) @@ -107,8 +107,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): val_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=False) logger.info("Using train set for validation") else: - _y_val = check_and_transform_label_format(y_validation, self.nb_classes) - x_val_preprocessed, y_val_preprocessed = self._apply_preprocessing(x_validation, _y_val, fit=False) + y_val = check_and_transform_label_format(y_validation, self.nb_classes) + x_val_preprocessed, y_val_preprocessed = self._apply_preprocessing(x_validation, y_val, fit=False) # Check label shape y_val_preprocessed = self.reduce_labels(y_val_preprocessed) val_dataset = TensorDataset(torch.from_numpy(x_val_preprocessed), torch.from_numpy(y_val_preprocessed)) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index b2b94ac..c1f36c4 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -59,12 +59,12 @@ def test_nursery_pytorch_state_dict(): nb_classes=4) model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10) model.load_latest_state_dict_checkpoint() - score = model.score(ArrayDataset(x_test.astype(np.float32), y_test)) + score = model.score(PytorchData(x_test.astype(np.float32), y_test)) print('Base model accuracy: ', score) assert (0 <= score <= 1) # python pytorch numpy model.load_best_state_dict_checkpoint() - score = model.score(ArrayDataset(x_test.astype(np.float32), y_test)) + score = model.score(PytorchData(x_test.astype(np.float32), y_test)) print('best model accuracy: ', score) assert (0 <= score <= 1) @@ -89,10 +89,10 @@ def test_nursery_pytorch_save_entire_model(): nb_classes=4) art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10) - score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test)) + score = art_model.score(PytorchData(x_test.astype(np.float32), y_test)) print('Base model accuracy: ', score) assert (0 <= score <= 1) art_model.load_best_model_checkpoint() - score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test)) + score = art_model.score(PytorchData(x_test.astype(np.float32), y_test)) print('best model accuracy: ', score) assert (0 <= score <= 1)