This commit is contained in:
olasaadi 2022-06-06 14:02:40 +03:00
parent 302d0c4b8c
commit c954f53ad7
2 changed files with 42 additions and 42 deletions

View file

@ -9,6 +9,37 @@ from apt.utils.models.pytorch_model import PyTorchClassifier
from art.utils import load_nursery
class pytorch_model(nn.Module):
def __init__(self, num_classes, num_features):
super(pytorch_model, self).__init__()
self.fc1 = nn.Sequential(
nn.Linear(num_features, 1024),
nn.Tanh(), )
self.fc2 = nn.Sequential(
nn.Linear(1024, 512),
nn.Tanh(), )
self.fc3 = nn.Sequential(
nn.Linear(512, 256),
nn.Tanh(), )
self.fc4 = nn.Sequential(
nn.Linear(256, 128),
nn.Tanh(),
)
self.classifier = nn.Linear(128, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.fc2(out)
out = self.fc3(out)
out = self.fc4(out)
return self.classifier(out)
def test_nursery_pytorch():
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
# reduce size of training set to make attack slightly better
@ -18,39 +49,10 @@ def test_nursery_pytorch():
x_test = x_test[:train_set_size]
y_test = y_test[:train_set_size]
class pytorch_model(nn.Module):
def __init__(self, num_classes, num_features):
super(pytorch_model, self).__init__()
self.fc1 = nn.Sequential(
nn.Linear(num_features, 1024),
nn.Tanh(), )
self.fc2 = nn.Sequential(
nn.Linear(1024, 512),
nn.Tanh(), )
self.fc3 = nn.Sequential(
nn.Linear(512, 256),
nn.Tanh(), )
self.fc4 = nn.Sequential(
nn.Linear(256, 128),
nn.Tanh(),
)
self.classifier = nn.Linear(128, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.fc2(out)
out = self.fc3(out)
out = self.fc4(out)
return self.classifier(out)
model = pytorch_model(4, 24)
model = torch.nn.DataParallel(model)
# model = torch.nn.DataParallel(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
@ -62,4 +64,4 @@ def test_nursery_pytorch():
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
art_model.load_best_model_checkpoint()
art_model.load_best_state_dict_checkpoint()