diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 1fd2a02..99e21c4 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -23,6 +23,7 @@ class PyTorchModel(Model): class PyTorchClassifierWrapper(ArtPyTorchClassifier): """ Wrapper class for pytorch classifier model. + Extension of the PyTorch ART model. """ def get_step_correct(self, outputs, targets) -> int: @@ -187,6 +188,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): if self._optimizer and 'opt_state_dict' in checkpoint: self._optimizer.load_state_dict(checkpoint['opt_state_dict']) + self.model.eval() def load_latest_state_dict_checkpoint(self): """ @@ -266,14 +268,23 @@ class PyTorchClassifier(PyTorchModel): super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs) self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer) - def fit(self, train_data: PytorchData, **kwargs) -> None: + def fit(self, train_data: PytorchData, batch_size: int = 128, nb_epochs: int = 10, + save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None: """ Fit the model using the training data. :param train_data: Training data. :type train_data: `Dataset` + :param batch_size: Size of batches. + :param nb_epochs: Number of epochs to use for training. + :param save_checkpoints: Boolean, save checkpoints if True. + :param save_entire_model: Boolean, save entire model if True, else save state dict. + :param path: path for saving checkpoint. + :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently + supported for PyTorch and providing it has no effect. 
""" - self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), **kwargs) + self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), batch_size, nb_epochs, + save_checkpoints, save_entire_model, path, **kwargs) def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE: """ diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index b1bdb68..3f6bc11 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -2,7 +2,7 @@ import numpy as np import torch from torch import nn, optim -from apt.utils.datasets import ArrayDataset, Data, Dataset +from apt.utils.datasets import ArrayDataset from apt.utils.datasets.datasets import PytorchData from apt.utils.models import ModelOutputType from apt.utils.models.pytorch_model import PyTorchClassifier @@ -40,6 +40,7 @@ class pytorch_model(nn.Module): out = self.fc4(out) return self.classifier(out) + def test_nursery_pytorch_state_dict(): (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5) # reduce size of training set to make attack slightly better @@ -49,22 +50,21 @@ def test_nursery_pytorch_state_dict(): x_test = x_test[:train_set_size] y_test = y_test[:train_set_size] - - - model = pytorch_model(4, 24) - # model = torch.nn.DataParallel(model) + inner_model = pytorch_model(4, 24) criterion = nn.CrossEntropyLoss() - optimizer = optim.Adam(model.parameters(), lr=0.01) + optimizer = optim.Adam(inner_model.parameters(), lr=0.01) - art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, - optimizer=optimizer, input_shape=(24,), - nb_classes=4) - art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False) - - pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))]) - - print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test)) - art_model.load_best_state_dict_checkpoint() + model = 
PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, + optimizer=optimizer, input_shape=(24,), + nb_classes=4) + model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100) + score = model.score(ArrayDataset(x_test.astype(np.float32), y_test)) + print('Base model accuracy: ', score) + assert (0 <= score <= 1) + model.load_best_state_dict_checkpoint() + score = model.score(ArrayDataset(x_test.astype(np.float32), y_test)) + print('Base model accuracy: ', score) + assert (0 <= score <= 1) def test_nursery_pytorch_save_entire_model(): @@ -76,8 +76,6 @@ def test_nursery_pytorch_save_entire_model(): x_test = x_test[:train_set_size] y_test = y_test[:train_set_size] - - model = pytorch_model(4, 24) # model = torch.nn.DataParallel(model) criterion = nn.CrossEntropyLoss() @@ -88,8 +86,10 @@ def test_nursery_pytorch_save_entire_model(): nb_classes=4) art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True) - pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))]) - - print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test)) + score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test)) + print('Base model accuracy: ', score) + assert (0 <= score <= 1) art_model.load_best_model_checkpoint() - + score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test)) + print('Base model accuracy: ', score) + assert (0 <= score <= 1)