From 3bf26b67d26665dea08fe96533a6879ae044308b Mon Sep 17 00:00:00 2001 From: olasaadi Date: Wed, 20 Jul 2022 17:36:00 +0300 Subject: [PATCH] fix --- apt/utils/models/pytorch_model.py | 1 + tests/test_pytorch.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 99e21c4..d71b164 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -122,6 +122,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): tot_correct += correct total += o_batch.shape[0] val_loss, val_acc = self._eval(x, y, num_batch, batch_size) + # print acc TODO best_acc = max(val_acc, best_acc) if save_checkpoints: if save_entire_model: diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 3f6bc11..6b3a1f7 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -58,9 +58,11 @@ def test_nursery_pytorch_state_dict(): optimizer=optimizer, input_shape=(24,), nb_classes=4) model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100) + model.load_latest_state_dict_checkpoint() score = model.score(ArrayDataset(x_test.astype(np.float32), y_test)) print('Base model accuracy: ', score) assert (0 <= score <= 1) + # python pytorch numpy model.load_best_state_dict_checkpoint() score = model.score(ArrayDataset(x_test.astype(np.float32), y_test)) print('Base model accuracy: ', score) @@ -68,6 +70,7 @@ def test_nursery_pytorch_state_dict(): def test_nursery_pytorch_save_entire_model(): + (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5) # reduce size of training set to make attack slightly better train_set_size = 500 @@ -90,6 +93,6 @@ def test_nursery_pytorch_save_entire_model(): print('Base model accuracy: ', score) assert (0 <= score <= 1) art_model.load_best_model_checkpoint() - #score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test)) + score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test)) print('Base model accuracy: ', score) assert (0 <= score <= 1)