mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-08 15:05:13 +02:00
fix
This commit is contained in:
parent
019f49861d
commit
59d8b16bb4
2 changed files with 66 additions and 1 deletions
|
|
@ -94,7 +94,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
loss.backward()
|
||||
|
||||
self._optimizer.step()
|
||||
correct = self.get_step_correct(model_outputs, o_batch)
|
||||
correct = self.get_step_correct(model_outputs[-1], o_batch)
|
||||
tot_correct += correct
|
||||
total += o_batch.shape[0]
|
||||
train_acc = float(tot_correct) / total
|
||||
|
|
|
|||
65
tests/test_pytorch.py
Normal file
65
tests/test_pytorch.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from apt.utils.datasets import ArrayDataset
|
||||
from apt.utils.models import ModelOutputType
|
||||
from apt.utils.models.pytorch_model import PyTorchClassifier
|
||||
from art.utils import load_nursery
|
||||
|
||||
|
||||
def test_nursery_pytorch():
|
||||
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
|
||||
# reduce size of training set to make attack slightly better
|
||||
train_set_size = 500
|
||||
x_train = x_train[:train_set_size]
|
||||
y_train = y_train[:train_set_size]
|
||||
x_test = x_test[:train_set_size]
|
||||
y_test = y_test[:train_set_size]
|
||||
|
||||
class pytorch_model(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, num_features):
|
||||
super(pytorch_model, self).__init__()
|
||||
|
||||
self.fc1 = nn.Sequential(
|
||||
nn.Linear(num_features, 1024),
|
||||
nn.Tanh(), )
|
||||
|
||||
self.fc2 = nn.Sequential(
|
||||
nn.Linear(1024, 512),
|
||||
nn.Tanh(), )
|
||||
|
||||
self.fc3 = nn.Sequential(
|
||||
nn.Linear(512, 256),
|
||||
nn.Tanh(), )
|
||||
|
||||
self.fc4 = nn.Sequential(
|
||||
nn.Linear(256, 128),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.classifier = nn.Linear(128, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
out = self.fc2(out)
|
||||
out = self.fc3(out)
|
||||
out = self.fc4(out)
|
||||
return self.classifier(out)
|
||||
|
||||
mlp_model = pytorch_model(4, 24)
|
||||
mlp_model = torch.nn.DataParallel(mlp_model)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(mlp_model.parameters(), lr=0.01)
|
||||
|
||||
mlp_art_model = PyTorchClassifier(model=mlp_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
||||
optimizer=optimizer, input_shape=(24,),
|
||||
nb_classes=4)
|
||||
mlp_art_model.fit(ArrayDataset(x_train.astype(np.float32), y_train))
|
||||
|
||||
pred = np.array([np.argmax(arr) for arr in mlp_art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
|
||||
|
||||
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
|
||||
Loading…
Add table
Add a link
Reference in a new issue