This commit is contained in:
olasaadi 2022-07-20 17:36:00 +03:00
parent 4973fbebc6
commit 3bf26b67d2
2 changed files with 5 additions and 1 deletions

View file

@ -122,6 +122,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
tot_correct += correct tot_correct += correct
total += o_batch.shape[0] total += o_batch.shape[0]
val_loss, val_acc = self._eval(x, y, num_batch, batch_size) val_loss, val_acc = self._eval(x, y, num_batch, batch_size)
# print acc TODO
best_acc = max(val_acc, best_acc) best_acc = max(val_acc, best_acc)
if save_checkpoints: if save_checkpoints:
if save_entire_model: if save_entire_model:

View file

@ -58,9 +58,11 @@ def test_nursery_pytorch_state_dict():
optimizer=optimizer, input_shape=(24,), optimizer=optimizer, input_shape=(24,),
nb_classes=4) nb_classes=4)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100) model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100)
model.load_latest_state_dict_checkpoint()
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test)) score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score) print('Base model accuracy: ', score)
assert (0 <= score <= 1) assert (0 <= score <= 1)
# python pytorch numpy
model.load_best_state_dict_checkpoint() model.load_best_state_dict_checkpoint()
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test)) score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score) print('Base model accuracy: ', score)
@ -68,6 +70,7 @@ def test_nursery_pytorch_state_dict():
def test_nursery_pytorch_save_entire_model(): def test_nursery_pytorch_save_entire_model():
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5) (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
# reduce size of training set to make attack slightly better # reduce size of training set to make attack slightly better
train_set_size = 500 train_set_size = 500
@ -90,6 +93,6 @@ def test_nursery_pytorch_save_entire_model():
print('Base model accuracy: ', score) print('Base model accuracy: ', score)
assert (0 <= score <= 1) assert (0 <= score <= 1)
art_model.load_best_model_checkpoint() art_model.load_best_model_checkpoint()
#score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test)) score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score) print('Base model accuracy: ', score)
assert (0 <= score <= 1) assert (0 <= score <= 1)