From c954f53ad7349770eaae0b521c60bf0fa8fc7847 Mon Sep 17 00:00:00 2001 From: olasaadi Date: Mon, 6 Jun 2022 14:02:40 +0300 Subject: [PATCH] fix --- apt/utils/models/pytorch_model.py | 20 +++++----- tests/test_pytorch.py | 64 ++++++++++++++++--------------- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 18d57b5..877abe4 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -83,10 +83,10 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): if self._optimizer is None: # pragma: no cover raise ValueError("An optimizer is needed to train the model, but none for provided.") - y = check_and_transform_label_format(y, self.nb_classes) + _y = check_and_transform_label_format(y, self.nb_classes) # Apply preprocessing - x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True) + x_preprocessed, y_preprocessed = self._apply_preprocessing(x, _y, fit=True) # Check label shape y_preprocessed = self.reduce_labels(y_preprocessed) @@ -127,14 +127,13 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): train_acc = float(tot_correct) / total best_acc = max(val_acc, best_acc) if save_checkpoints: - self.save_checkpoint_model(is_best=best_acc <= val_acc) + self.save_checkpoint_state_dict(is_best=best_acc <= val_acc) - def save_checkpoint_state_dict(self, is_best: bool, additional_states: Dict = None, + def save_checkpoint_state_dict(self, is_best: bool, filename="latest.tar") -> None: """ Saves checkpoint as latest.tar or best.tar :param is_best: whether the model is the best achieved model - :param additional_states: additional parameters that will be saved with the model :param filename: checkpoint name :return: None """ @@ -143,9 +142,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): path = checkpoint os.makedirs(path, exist_ok=True) filepath = os.path.join(path, filename) - state = additional_states if additional_states else dict() - state['state_dict'] = self.model.module.state_dict() \ - if isinstance(self.model, torch.nn.DataParallel) else self.model.state_dict() + state = dict() + state['state_dict'] = self.model.state_dict() state['opt_state_dict'] = self.optimizer.state_dict() torch.save(state, filepath) if is_best: @@ -162,7 +160,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): path = checkpoint os.makedirs(path, exist_ok=True) filepath = os.path.join(path, filename) - torch.save(self.model.module, filepath) + torch.save(self.model, filepath) if is_best: shutil.copyfile(filepath, os.path.join(path, 'model_best.tar')) @@ -184,7 +182,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): else: checkpoint = torch.load(filepath) - self.model.module.load_state_dict(checkpoint) + self.model.load_state_dict(checkpoint['state_dict']) if self._optimizer and 'opt_state_dict' in checkpoint: self._optimizer.load_state_dict(checkpoint['opt_state_dict']) @@ -220,7 +218,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): raise FileNotFoundError(msg) else: - self.model.module = torch.load(path) + self._model = torch.load(filepath) def load_latest_model_checkpoint(self): """ diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 8841303..aaa6830 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -9,6 +9,37 @@ from apt.utils.models.pytorch_model import PyTorchClassifier from art.utils import load_nursery +class pytorch_model(nn.Module): + + def __init__(self, num_classes, num_features): + super(pytorch_model, self).__init__() + + self.fc1 = nn.Sequential( + nn.Linear(num_features, 1024), + nn.Tanh(), ) + + self.fc2 = nn.Sequential( + nn.Linear(1024, 512), + nn.Tanh(), ) + + self.fc3 = nn.Sequential( + nn.Linear(512, 256), + nn.Tanh(), ) + + self.fc4 = nn.Sequential( + nn.Linear(256, 128), + nn.Tanh(), + ) + + self.classifier = nn.Linear(128, num_classes) + + def forward(self, x): + out = self.fc1(x) + out = self.fc2(out) + out = self.fc3(out) + out = self.fc4(out) + return self.classifier(out) + def test_nursery_pytorch(): (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5) # reduce size of training set to make attack slightly better @@ -18,39 +49,10 @@ def test_nursery_pytorch(): x_test = x_test[:train_set_size] y_test = y_test[:train_set_size] - class pytorch_model(nn.Module): - def __init__(self, num_classes, num_features): - super(pytorch_model, self).__init__() - - self.fc1 = nn.Sequential( - nn.Linear(num_features, 1024), - nn.Tanh(), ) - - self.fc2 = nn.Sequential( - nn.Linear(1024, 512), - nn.Tanh(), ) - - self.fc3 = nn.Sequential( - nn.Linear(512, 256), - nn.Tanh(), ) - - self.fc4 = nn.Sequential( - nn.Linear(256, 128), - nn.Tanh(), - ) - - self.classifier = nn.Linear(128, num_classes) - - def forward(self, x): - out = self.fc1(x) - out = self.fc2(out) - out = self.fc3(out) - out = self.fc4(out) - return self.classifier(out) model = pytorch_model(4, 24) - model = torch.nn.DataParallel(model) + # model = torch.nn.DataParallel(model) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.01) @@ -62,4 +64,4 @@ def test_nursery_pytorch(): pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))]) print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test)) - art_model.load_best_model_checkpoint() + art_model.load_best_state_dict_checkpoint()