diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 877abe4..a57830b 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -65,7 +65,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): return total_loss / total, float(correct) / total def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, - save_checkpoints: bool = True, **kwargs) -> None: + save_checkpoints: bool = True, save_entire_model: bool = True, path=None, **kwargs) -> None: """ Fit the classifier on the training set `(x, y)`. :param x: Training data. @@ -74,6 +74,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): :param batch_size: Size of batches. :param nb_epochs: Number of epochs to use for training. :param save_checkpoints: Boolean, save checkpoints if True. + :param save_entire_model: Boolean, save entire model if True, else save state dict. + :param path: directory for saving checkpoints; defaults to the current working directory. :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch and providing it takes no effect. 
""" @@ -127,18 +129,22 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): train_acc = float(tot_correct) / total best_acc = max(val_acc, best_acc) if save_checkpoints: - self.save_checkpoint_state_dict(is_best=best_acc <= val_acc) + if save_entire_model: + self.save_checkpoint_model(is_best=best_acc <= val_acc, path=path) + else: + self.save_checkpoint_state_dict(is_best=best_acc <= val_acc, path=path) - def save_checkpoint_state_dict(self, is_best: bool, + def save_checkpoint_state_dict(self, is_best: bool, path=None, filename="latest.tar") -> None: """ Saves checkpoint as latest.tar or best.tar :param is_best: whether the model is the best achieved model + :param path: directory for saving checkpoints; defaults to the current working directory :param filename: checkpoint name :return: None """ # add path - checkpoint = os.path.join(os.getcwd(), 'checkpoints') + checkpoint = os.path.join(path or os.getcwd(), 'checkpoints') path = checkpoint os.makedirs(path, exist_ok=True) filepath = os.path.join(path, filename) @@ -149,14 +155,15 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): if is_best: shutil.copyfile(filepath, os.path.join(path, 'model_best.tar')) - def save_checkpoint_model(self, is_best: bool, filename="latest.tar") -> None: + def save_checkpoint_model(self, is_best: bool, path=None, filename="latest.tar") -> None: """ Saves checkpoint as latest.tar or best.tar :param is_best: whether the model is the best achieved model + :param path: directory for saving checkpoints; defaults to the current working directory :param filename: checkpoint name :return: None """ - checkpoint = os.path.join(os.getcwd(), 'checkpoints') + checkpoint = os.path.join(path or os.getcwd(), 'checkpoints') path = checkpoint os.makedirs(path, exist_ok=True) filepath = os.path.join(path, filename) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index aaa6830..b1bdb68 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -40,7 +40,7 @@ class pytorch_model(nn.Module): out = self.fc4(out) return self.classifier(out) -def test_nursery_pytorch(): +def test_nursery_pytorch_state_dict(): 
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5) # reduce size of training set to make attack slightly better train_set_size = 500 @@ -59,9 +59,37 @@ def test_nursery_pytorch(): art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, optimizer=optimizer, input_shape=(24,), nb_classes=4) - art_model.fit(PytorchData(x_train.astype(np.float32), y_train)) + art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False) pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))]) print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test)) art_model.load_best_state_dict_checkpoint() + + +def test_nursery_pytorch_save_entire_model(): + (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5) + # reduce size of training set to make attack slightly better + train_set_size = 500 + x_train = x_train[:train_set_size] + y_train = y_train[:train_set_size] + x_test = x_test[:train_set_size] + y_test = y_test[:train_set_size] + + + + model = pytorch_model(4, 24) + # model = torch.nn.DataParallel(model) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.01) + + art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, + optimizer=optimizer, input_shape=(24,), + nb_classes=4) + art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True) + + pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))]) + + print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test)) + art_model.load_best_model_checkpoint() +