This commit is contained in:
olasaadi 2022-07-19 21:16:39 +03:00
parent 07e64b1f86
commit 4973fbebc6
2 changed files with 34 additions and 23 deletions

View file

@ -2,7 +2,7 @@ import numpy as np
import torch
from torch import nn, optim
from apt.utils.datasets import ArrayDataset, Data, Dataset
from apt.utils.datasets import ArrayDataset
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import ModelOutputType
from apt.utils.models.pytorch_model import PyTorchClassifier
@ -40,6 +40,7 @@ class pytorch_model(nn.Module):
out = self.fc4(out)
return self.classifier(out)
def test_nursery_pytorch_state_dict():
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
# reduce size of training set to make attack slightly better
@ -49,22 +50,21 @@ def test_nursery_pytorch_state_dict():
x_test = x_test[:train_set_size]
y_test = y_test[:train_set_size]
model = pytorch_model(4, 24)
# model = torch.nn.DataParallel(model)
inner_model = pytorch_model(4, 24)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
optimizer = optim.Adam(inner_model.parameters(), lr=0.01)
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False)
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
art_model.load_best_state_dict_checkpoint()
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100)
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)
model.load_best_state_dict_checkpoint()
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)
def test_nursery_pytorch_save_entire_model():
@ -76,8 +76,6 @@ def test_nursery_pytorch_save_entire_model():
x_test = x_test[:train_set_size]
y_test = y_test[:train_set_size]
model = pytorch_model(4, 24)
# model = torch.nn.DataParallel(model)
criterion = nn.CrossEntropyLoss()
@ -88,8 +86,10 @@ def test_nursery_pytorch_save_entire_model():
nb_classes=4)
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True)
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)
art_model.load_best_model_checkpoint()
#score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)