From 6f69f5557b5253b06bac9c9f5f3d98840239585b Mon Sep 17 00:00:00 2001 From: olasaadi Date: Wed, 20 Jul 2022 18:29:48 +0300 Subject: [PATCH] fix bug --- apt/utils/models/pytorch_model.py | 5 +++-- tests/test_pytorch.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index d71b164..13627cb 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -78,6 +78,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): """ # Put the model in the training mode self._model.train() + print(nb_epochs) if self._optimizer is None: # pragma: no cover raise ValueError("An optimizer is needed to train the model, but none for provided.") @@ -122,7 +123,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): tot_correct += correct total += o_batch.shape[0] val_loss, val_acc = self._eval(x, y, num_batch, batch_size) - # print acc TODO + print(val_acc) best_acc = max(val_acc, best_acc) if save_checkpoints: if save_entire_model: @@ -222,7 +223,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): raise FileNotFoundError(msg) else: - self._model = torch.load(filepath) + self._model._model = torch.load(filepath) def load_latest_model_checkpoint(self): """ diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 6b3a1f7..6095761 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -57,7 +57,7 @@ def test_nursery_pytorch_state_dict(): model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, optimizer=optimizer, input_shape=(24,), nb_classes=4) - model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100) + model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=1000) model.load_latest_state_dict_checkpoint() score = model.score(ArrayDataset(x_test.astype(np.float32), y_test)) print('Base model accuracy: ', score) @@ -65,7 +65,7 @@ def test_nursery_pytorch_state_dict(): # python pytorch numpy model.load_best_state_dict_checkpoint() score = model.score(ArrayDataset(x_test.astype(np.float32), y_test)) - print('Base model accuracy: ', score) + print('best model accuracy: ', score) assert (0 <= score <= 1) @@ -87,12 +87,12 @@ def test_nursery_pytorch_save_entire_model(): art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, optimizer=optimizer, input_shape=(24,), nb_classes=4) - art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True) + art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10) score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test)) print('Base model accuracy: ', score) assert (0 <= score <= 1) art_model.load_best_model_checkpoint() score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test)) - print('Base model accuracy: ', score) + print('best model accuracy: ', score) assert (0 <= score <= 1)