# ai-privacy-toolkit/tests/test_pytorch.py
import numpy as np
from torch import nn, optim
2022-06-02 15:25:07 +03:00
from apt.utils.datasets.datasets import PytorchData
2022-05-23 12:49:38 +03:00
from apt.utils.models import ModelOutputType
from apt.utils.models.pytorch_model import PyTorchClassifier
from art.utils import load_nursery
2022-06-06 14:02:40 +03:00
class pytorch_model(nn.Module):
    """Fully-connected classifier used by the PyTorch wrapper tests.

    Four Tanh-activated hidden layers (num_features -> 1024 -> 512 -> 256
    -> 128) followed by a linear classification head that emits raw logits.
    """

    def __init__(self, num_classes, num_features):
        super(pytorch_model, self).__init__()
        # Attribute names (fc1..fc4, classifier) are kept stable so that
        # previously saved state_dict checkpoints remain loadable.
        self.fc1 = nn.Sequential(nn.Linear(num_features, 1024), nn.Tanh())
        self.fc2 = nn.Sequential(nn.Linear(1024, 512), nn.Tanh())
        self.fc3 = nn.Sequential(nn.Linear(512, 256), nn.Tanh())
        self.fc4 = nn.Sequential(nn.Linear(256, 128), nn.Tanh())
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        """Run *x* through the hidden stack and return class logits."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = layer(hidden)
        return self.classifier(hidden)
2022-07-19 21:16:39 +03:00
2022-06-06 14:32:34 +03:00
def test_nursery_pytorch_state_dict():
    """Train with state-dict checkpointing and verify that both the latest
    and the best checkpoints can be restored and produce a valid score.
    """
    (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
    # Cap both splits at 500 samples to keep the test fast.
    subset_size = 500
    x_train = x_train[:subset_size]
    y_train = y_train[:subset_size]
    x_test = x_test[:subset_size]
    y_test = y_test[:subset_size]

    inner_model = pytorch_model(4, 24)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(inner_model.parameters(), lr=0.01)
    model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion,
                              optimizer=optimizer, input_shape=(24,),
                              nb_classes=4)

    # Build each dataset once; the test set is scored twice below.
    train_data = PytorchData(x_train.astype(np.float32), y_train)
    test_data = PytorchData(x_test.astype(np.float32), y_test)
    model.fit(train_data, save_entire_model=False, nb_epochs=10)

    # Restore the checkpoint saved after the most recent epoch.
    model.load_latest_state_dict_checkpoint()
    score = model.score(test_data)
    print('Base model accuracy: ', score)
    assert (0 <= score <= 1)

    # Restore the checkpoint that recorded the best score during training.
    model.load_best_state_dict_checkpoint()
    score = model.score(test_data)
    print('best model accuracy: ', score)
    assert (0 <= score <= 1)
2022-06-06 14:32:34 +03:00
def test_nursery_pytorch_save_entire_model():
    """Train with entire-model checkpointing and verify that the best
    saved model can be restored and produces a valid score.
    """
    (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
    # Cap both splits at 500 samples to keep the test fast.
    subset_size = 500
    x_train = x_train[:subset_size]
    y_train = y_train[:subset_size]
    x_test = x_test[:subset_size]
    y_test = y_test[:subset_size]

    inner_model = pytorch_model(4, 24)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(inner_model.parameters(), lr=0.01)
    model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion,
                              optimizer=optimizer, input_shape=(24,),
                              nb_classes=4)

    # Build each dataset once; the test set is scored twice below.
    train_data = PytorchData(x_train.astype(np.float32), y_train)
    test_data = PytorchData(x_test.astype(np.float32), y_test)
    model.fit(train_data, save_entire_model=True, nb_epochs=10)

    # Score the model as left by training (final-epoch weights).
    score = model.score(test_data)
    print('Base model accuracy: ', score)
    assert (0 <= score <= 1)

    # Restore the fully-serialized model checkpoint with the best score.
    model.load_best_model_checkpoint()
    score = model.score(test_data)
    print('best model accuracy: ', score)
    assert (0 <= score <= 1)