mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-08 15:05:13 +02:00
fix
This commit is contained in:
parent
302d0c4b8c
commit
c954f53ad7
2 changed files with 42 additions and 42 deletions
|
|
@ -83,10 +83,10 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
if self._optimizer is None: # pragma: no cover
|
||||
raise ValueError("An optimizer is needed to train the model, but none for provided.")
|
||||
|
||||
y = check_and_transform_label_format(y, self.nb_classes)
|
||||
_y = check_and_transform_label_format(y, self.nb_classes)
|
||||
|
||||
# Apply preprocessing
|
||||
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
|
||||
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, _y, fit=True)
|
||||
|
||||
# Check label shape
|
||||
y_preprocessed = self.reduce_labels(y_preprocessed)
|
||||
|
|
@ -127,14 +127,13 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
train_acc = float(tot_correct) / total
|
||||
best_acc = max(val_acc, best_acc)
|
||||
if save_checkpoints:
|
||||
self.save_checkpoint_model(is_best=best_acc <= val_acc)
|
||||
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc)
|
||||
|
||||
def save_checkpoint_state_dict(self, is_best: bool, additional_states: Dict = None,
|
||||
def save_checkpoint_state_dict(self, is_best: bool,
|
||||
filename="latest.tar") -> None:
|
||||
"""
|
||||
Saves checkpoint as latest.tar or best.tar
|
||||
:param is_best: whether the model is the best achieved model
|
||||
:param additional_states: additional parameters that will be saved with the model
|
||||
:param filename: checkpoint name
|
||||
:return: None
|
||||
"""
|
||||
|
|
@ -143,9 +142,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
path = checkpoint
|
||||
os.makedirs(path, exist_ok=True)
|
||||
filepath = os.path.join(path, filename)
|
||||
state = additional_states if additional_states else dict()
|
||||
state['state_dict'] = self.model.module.state_dict() \
|
||||
if isinstance(self.model, torch.nn.DataParallel) else self.model.state_dict()
|
||||
state = dict()
|
||||
state['state_dict'] = self.model.state_dict()
|
||||
state['opt_state_dict'] = self.optimizer.state_dict()
|
||||
torch.save(state, filepath)
|
||||
if is_best:
|
||||
|
|
@ -162,7 +160,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
path = checkpoint
|
||||
os.makedirs(path, exist_ok=True)
|
||||
filepath = os.path.join(path, filename)
|
||||
torch.save(self.model.module, filepath)
|
||||
torch.save(self.model, filepath)
|
||||
if is_best:
|
||||
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
|
||||
|
||||
|
|
@ -184,7 +182,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
|
||||
else:
|
||||
checkpoint = torch.load(filepath)
|
||||
self.model.module.load_state_dict(checkpoint)
|
||||
self.model.load_state_dict(checkpoint['state_dict'])
|
||||
|
||||
if self._optimizer and 'opt_state_dict' in checkpoint:
|
||||
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
|
||||
|
|
@ -220,7 +218,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
raise FileNotFoundError(msg)
|
||||
|
||||
else:
|
||||
self.model.module = torch.load(path)
|
||||
self._model = torch.load(filepath)
|
||||
|
||||
def load_latest_model_checkpoint(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -9,6 +9,37 @@ from apt.utils.models.pytorch_model import PyTorchClassifier
|
|||
from art.utils import load_nursery
|
||||
|
||||
|
||||
class pytorch_model(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, num_features):
|
||||
super(pytorch_model, self).__init__()
|
||||
|
||||
self.fc1 = nn.Sequential(
|
||||
nn.Linear(num_features, 1024),
|
||||
nn.Tanh(), )
|
||||
|
||||
self.fc2 = nn.Sequential(
|
||||
nn.Linear(1024, 512),
|
||||
nn.Tanh(), )
|
||||
|
||||
self.fc3 = nn.Sequential(
|
||||
nn.Linear(512, 256),
|
||||
nn.Tanh(), )
|
||||
|
||||
self.fc4 = nn.Sequential(
|
||||
nn.Linear(256, 128),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.classifier = nn.Linear(128, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
out = self.fc2(out)
|
||||
out = self.fc3(out)
|
||||
out = self.fc4(out)
|
||||
return self.classifier(out)
|
||||
|
||||
def test_nursery_pytorch():
|
||||
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
|
||||
# reduce size of training set to make attack slightly better
|
||||
|
|
@ -18,39 +49,10 @@ def test_nursery_pytorch():
|
|||
x_test = x_test[:train_set_size]
|
||||
y_test = y_test[:train_set_size]
|
||||
|
||||
class pytorch_model(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, num_features):
|
||||
super(pytorch_model, self).__init__()
|
||||
|
||||
self.fc1 = nn.Sequential(
|
||||
nn.Linear(num_features, 1024),
|
||||
nn.Tanh(), )
|
||||
|
||||
self.fc2 = nn.Sequential(
|
||||
nn.Linear(1024, 512),
|
||||
nn.Tanh(), )
|
||||
|
||||
self.fc3 = nn.Sequential(
|
||||
nn.Linear(512, 256),
|
||||
nn.Tanh(), )
|
||||
|
||||
self.fc4 = nn.Sequential(
|
||||
nn.Linear(256, 128),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.classifier = nn.Linear(128, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
out = self.fc2(out)
|
||||
out = self.fc3(out)
|
||||
out = self.fc4(out)
|
||||
return self.classifier(out)
|
||||
|
||||
model = pytorch_model(4, 24)
|
||||
model = torch.nn.DataParallel(model)
|
||||
# model = torch.nn.DataParallel(model)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.01)
|
||||
|
||||
|
|
@ -62,4 +64,4 @@ def test_nursery_pytorch():
|
|||
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
|
||||
|
||||
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
|
||||
art_model.load_best_model_checkpoint()
|
||||
art_model.load_best_state_dict_checkpoint()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue