diff --git a/apt/utils/datasets/datasets.py b/apt/utils/datasets/datasets.py index 7eae95f..2e0a70a 100644 --- a/apt/utils/datasets/datasets.py +++ b/apt/utils/datasets/datasets.py @@ -19,9 +19,42 @@ from torch import Tensor logger = logging.getLogger(__name__) -INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, pd.Series, List, Tensor] +INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor] OUTPUT_DATA_ARRAY_TYPE = np.ndarray -DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame, pd.Series] +DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame] + + +def array2numpy(arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE: + + """ + converts from INPUT_DATA_ARRAY_TYPE to numpy array + """ + if type(arr) == np.ndarray: + return arr + if type(arr) == pd.DataFrame or type(arr) == pd.Series: + return arr.to_numpy() + if isinstance(arr, list): + return np.array(arr) + if type(arr) == Tensor: + return arr.detach().cpu().numpy() + + raise ValueError("Non supported type: ", type(arr).__name__) + + +def array2torch_tensor(arr: INPUT_DATA_ARRAY_TYPE) -> Tensor: + """ + converts from INPUT_DATA_ARRAY_TYPE to torch tensor array + """ + if type(arr) == np.ndarray: + return torch.from_numpy(arr) + if type(arr) == pd.DataFrame or type(arr) == pd.Series: + return torch.from_numpy(arr.to_numpy()) + if isinstance(arr, list): + return torch.tensor(arr) + if type(arr) == Tensor: + return arr + + raise ValueError("Non supported type: ", type(arr).__name__) class Dataset(metaclass=ABCMeta): @@ -58,46 +91,6 @@ class Dataset(metaclass=ABCMeta): """ raise NotImplementedError - def _array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE: - """ - Converts from INPUT_DATA_ARRAY_TYPE to numpy array - - :param arr: the array to transform - :type arr: numpy array or pandas DataFrame or list or pytorch Tensor - :return: the array transformed into a numpy array - """ - if type(arr) == np.ndarray: - return arr - if type(arr) == pd.DataFrame or type(arr) == 
pd.Series: - self.is_pandas = True - return arr.to_numpy() - if isinstance(arr, list): - return np.array(arr) - if type(arr) == Tensor: - return arr.detach().cpu().numpy() - - raise ValueError('Non supported type: ', type(arr).__name__) - - def _array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor: - """ - Converts from INPUT_DATA_ARRAY_TYPE to torch tensor array - - :param arr: the array to transform - :type arr: numpy array or pandas DataFrame or list or pytorch Tensor - :return: the array transformed into a pytorch Tensor - """ - if type(arr) == np.ndarray: - return torch.from_numpy(arr) - if type(arr) == pd.DataFrame or type(arr) == pd.Series: - self.is_pandas = True - return torch.from_numpy(arr.to_numpy()) - if isinstance(arr, list): - return torch.tensor(arr) - if type(arr) == Tensor: - return arr - - raise ValueError('Non supported type: ', type(arr).__name__) - class StoredDataset(Dataset): """Abstract Class for a Dataset that can be downloaded from a URL and stored in a file""" @@ -146,7 +139,7 @@ class StoredDataset(Dataset): os.makedirs(dest_path, exist_ok=True) logger.info("Downloading the dataset...") urllib.request.urlretrieve(url, file_path) - logger.info('Dataset Downloaded') + logger.info("Dataset Downloaded") if unzip: StoredDataset.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False) @@ -205,7 +198,7 @@ class StoredDataset(Dataset): logger.info("Shuffling data") np.random.shuffle(data) - debug_data = data[:int(len(data) * ratio)] + debug_data = data[: int(len(data) * ratio)] logger.info(f"Saving {ratio} of the data to {dest_datafile}") np.savetxt(dest_datafile, debug_data, delimiter=delimiter, fmt=fmt) @@ -224,17 +217,19 @@ class ArrayDataset(Dataset): def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, features_names: Optional[list] = None, **kwargs): - self.is_pandas = False + self.is_pandas = type(x) == pd.DataFrame or type(x) == pd.Series + 
self.features_names = features_names - self._y = self._array2numpy(y) if y is not None else None - self._x = self._array2numpy(x) + self._y = array2numpy(y) if y is not None else None + self._x = array2numpy(x) + if self.is_pandas: if features_names and not np.array_equal(features_names, x.columns): raise ValueError("The supplied features are not the same as in the data features") self.features_names = x.columns.to_list() if self._y is not None and len(self._x) != len(self._y): - raise ValueError('Non equivalent lengths of x and y') + raise ValueError("Non equivalent lengths of x and y") def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE: """ @@ -278,9 +273,9 @@ class DatasetWithPredictions(Dataset): y: Optional[INPUT_DATA_ARRAY_TYPE] = None, features_names: Optional[list] = None, **kwargs): self.is_pandas = False self.features_names = features_names - self._pred = self._array2numpy(pred) - self._y = self._array2numpy(y) if y is not None else None - self._x = self._array2numpy(x) if x is not None else None + self._pred = array2numpy(pred) + self._y = array2numpy(y) if y is not None else None + self._x = array2numpy(x) if x is not None else None if self.is_pandas and x is not None: if features_names and not np.array_equal(features_names, x.columns): raise ValueError("The supplied features are not the same as in the data features") @@ -327,14 +322,16 @@ class PytorchData(Dataset): :type y: numpy array or pandas DataFrame or list or pytorch Tensor, optional """ def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs): - self.is_pandas = False - self._y = self._array2torch_tensor(y) if y is not None else None - self._x = self._array2torch_tensor(x) + self._y = array2torch_tensor(y) if y is not None else None + self._x = array2torch_tensor(x) + + self.is_pandas = type(x) == pd.DataFrame or type(x) == pd.Series + if self.is_pandas: self.features_names = x.columns if self._y is not None and len(self._x) != len(self._y): - raise 
ValueError('Non equivalent lengths of x and y') + raise ValueError("Non equivalent lengths of x and y") if self._y is not None: self.__getitem__ = self.get_item @@ -347,7 +344,7 @@ class PytorchData(Dataset): :return: samples as numpy array """ - return self._array2numpy(self._x) + return array2numpy(self._x) def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE: """ @@ -355,7 +352,7 @@ class PytorchData(Dataset): :return: labels as numpy array """ - return self._array2numpy(self._y) if self._y is not None else None + return array2numpy(self._y) if self._y is not None else None def get_predictions(self) -> OUTPUT_DATA_ARRAY_TYPE: """ @@ -392,6 +389,7 @@ class PytorchData(Dataset): class DatasetFactory: """Factory class for dataset creation""" + registry = {} @classmethod @@ -406,7 +404,7 @@ class DatasetFactory: def inner_wrapper(wrapped_class: Type[Dataset]) -> Any: if name in cls.registry: - logger.warning('Dataset %s already exists. Will replace it', name) + logger.warning("Dataset %s already exists. Will replace it", name) cls.registry[name] = wrapped_class return wrapped_class @@ -428,7 +426,7 @@ class DatasetFactory: :return: An instance of the dataset that is created. 
""" if name not in cls.registry: - msg = f'Dataset {name} does not exist in the registry' + msg = f"Dataset {name} does not exist in the registry" logger.error(msg) raise ValueError(msg) diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py new file mode 100644 index 0000000..f3eb86c --- /dev/null +++ b/apt/utils/models/pytorch_model.py @@ -0,0 +1,425 @@ +""" Pytorch Model Wrapper""" +import os +import shutil +import logging + +from typing import Optional, Tuple +import numpy as np +import torch +from torch.utils.data import DataLoader, TensorDataset + +from art.utils import check_and_transform_label_format, logger +from apt.utils.datasets.datasets import PytorchData +from apt.utils.models import Model, ModelOutputType +from apt.utils.datasets import OUTPUT_DATA_ARRAY_TYPE +from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier + + +logger = logging.getLogger(__name__) + + +class PyTorchModel(Model): + """ + Wrapper class for pytorch models. + """ + + +class PyTorchClassifierWrapper(ArtPyTorchClassifier): + """ + Wrapper class for pytorch classifier model. 
+ Extension for Pytorch ART model + """ + + def get_step_correct(self, outputs, targets) -> int: + """get number of correctly classified labels""" + if len(outputs) != len(targets): + raise ValueError("outputs and targets should be the same length.") + if self.nb_classes > 1: + return int(torch.sum(torch.argmax(outputs, axis=-1) == targets).item()) + else: + return int(torch.sum(torch.round(outputs) == targets).item()) + + def _eval(self, loader: DataLoader): + """inner function for model evaluation""" + self.model.eval() + + total_loss = 0 + correct = 0 + total = 0 + + for inputs, targets in loader: + inputs = inputs.to(self._device) + targets = targets.to(self._device) + + outputs = self.model(inputs) + loss = self._loss(outputs, targets) + total_loss += loss.item() * targets.size(0) + total += targets.size(0) + correct += self.get_step_correct(outputs, targets) + + return total_loss / total, float(correct) / total + + def fit( + self, + x: np.ndarray, + y: np.ndarray, + x_validation: np.ndarray = None, + y_validation: np.ndarray = None, + batch_size: int = 128, + nb_epochs: int = 10, + save_checkpoints: bool = True, + save_entire_model=True, + path=os.getcwd(), + **kwargs, + ) -> None: + """ + Fit the classifier on the training set `(x, y)`. + :param x: Training data. + :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels + of shape (nb_samples,). + :param x_validation: Validation data (optional). + :param y_validation: Target validation values (class labels) one-hot-encoded of shape + (nb_samples, nb_classes) or index labels of shape (nb_samples,) (optional). + :param batch_size: Size of batches. + :param nb_epochs: Number of epochs to use for training. + :param save_checkpoints: Boolean, save checkpoints if True. + :param save_entire_model: Boolean, save entire model if True, else save state dict. + :param path: path for saving checkpoint. + :param kwargs: Dictionary of framework-specific arguments. 
This parameter is not currently + supported for PyTorch and providing it takes no effect. + """ + # Put the model in the training mode + self._model.train() + + if self._optimizer is None: # pragma: no cover + raise ValueError("An optimizer is needed to train the model, but none for provided.") + + y = check_and_transform_label_format(y, self.nb_classes) + # Apply preprocessing + x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True) + # Check label shape + y_preprocessed = self.reduce_labels(y_preprocessed) + + train_dataset = TensorDataset(torch.from_numpy(x_preprocessed), torch.from_numpy(y_preprocessed)) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + if x_validation is None or y_validation is None: + val_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=False) + logger.info("Using train set for validation") + else: + y_val = check_and_transform_label_format(y_validation, self.nb_classes) + x_val_preprocessed, y_val_preprocessed = self._apply_preprocessing(x_validation, y_val, fit=False) + # Check label shape + y_val_preprocessed = self.reduce_labels(y_val_preprocessed) + val_dataset = TensorDataset(torch.from_numpy(x_val_preprocessed), torch.from_numpy(y_val_preprocessed)) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False) + + # Start training + best_acc = 0 + for epoch in range(nb_epochs): + tot_correct = 0 + total = 0 + # Shuffle the examples + + # Train for one epoch + for inputs, targets in train_loader: + inputs = inputs.to(self._device) + targets = targets.to(self._device) + # Zero the parameter gradients + self._optimizer.zero_grad() + + # Perform prediction + model_outputs = self._model(inputs) + + # Form the loss function + loss = self._loss(model_outputs[-1], targets) + + loss.backward() + + self._optimizer.step() + correct = self.get_step_correct(model_outputs[-1], targets) + tot_correct += correct + total += 
targets.shape[0] + + val_loss, val_acc = self._eval(val_loader) + logger.info(f"Epoch{epoch + 1}/{nb_epochs} Val_Loss: {val_loss}, Val_Acc: {val_acc}") + + best_acc = max(val_acc, best_acc) + if save_checkpoints: + if save_entire_model: + self.save_checkpoint_model(is_best=best_acc <= val_acc, path=path) + else: + self.save_checkpoint_state_dict(is_best=best_acc <= val_acc, path=path) + + def save_checkpoint_state_dict(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None: + """ + Saves checkpoint as latest.tar or best.tar + :param is_best: whether the model is the best achieved model + :param path: path for saving checkpoint + :param filename: checkpoint name + :return: None + """ + # add path + checkpoint = os.path.join(path, "checkpoints") + path = checkpoint + os.makedirs(path, exist_ok=True) + filepath = os.path.join(path, filename) + state = dict() + state["state_dict"] = self.model.state_dict() + state["opt_state_dict"] = self.optimizer.state_dict() + + logger.info(f"Saving checkpoint state dictionary: {filepath}") + torch.save(state, filepath) + if is_best: + shutil.copyfile(filepath, os.path.join(path, "model_best.tar")) + logger.info(f"Saving best state dictionary checkpoint: {os.path.join(path, 'model_best.tar')}") + + def save_checkpoint_model(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None: + """ + Saves checkpoint as latest.tar or best.tar + :param is_best: whether the model is the best achieved model + :param path: path for saving checkpoint + :param filename: checkpoint name + :return: None + """ + checkpoint = os.path.join(path, "checkpoints") + path = checkpoint + os.makedirs(path, exist_ok=True) + filepath = os.path.join(path, filename) + logger.info(f"Saving checkpoint model : {filepath}") + torch.save(self.model, filepath) + if is_best: + shutil.copyfile(filepath, os.path.join(path, "model_best.tar")) + logger.info(f"Saving best checkpoint model: {os.path.join(path, 'model_best.tar')}") + + def 
load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None): + """ + Load model only based on the check point path + :param model_name: check point filename + :param path: checkpoint path (default current work dir) + :return: loaded model + """ + if path is None: + path = os.path.join(os.getcwd(), "checkpoints") + + filepath = os.path.join(path, model_name) + if not os.path.exists(filepath): + msg = f"Model file {filepath} not found" + logger.error(msg) + raise FileNotFoundError(msg) + + else: + checkpoint = torch.load(filepath) + self.model.load_state_dict(checkpoint["state_dict"]) + self.model.to(self.device) + + if self._optimizer and "opt_state_dict" in checkpoint: + self._optimizer.load_state_dict(checkpoint["opt_state_dict"]) + self.model.eval() + + def load_latest_state_dict_checkpoint(self): + """ + Load model state dict only based on the check point path (latest.tar) + :return: loaded model + """ + self.load_checkpoint_state_dict_by_path("latest.tar") + + def load_best_state_dict_checkpoint(self): + """ + Load model state dict only based on the check point path (model_best.tar) + :return: loaded model + """ + self.load_checkpoint_state_dict_by_path("model_best.tar") + + def load_checkpoint_model_by_path(self, model_name: str, path: str = None): + """ + Load model only based on the check point path + :param model_name: check point filename + :param path: checkpoint path (default current work dir) + :return: loaded model + """ + if path is None: + path = os.path.join(os.getcwd(), "checkpoints") + + filepath = os.path.join(path, model_name) + if not os.path.exists(filepath): + msg = f"Model file {filepath} not found" + logger.error(msg) + raise FileNotFoundError(msg) + + else: + self._model._model = torch.load(filepath, map_location=self.device) + self.model.to(self.device) + self.model.eval() + + def load_latest_model_checkpoint(self): + """ + Load entire model only based on the check point path (latest.tar) + :return: loaded model + """ + 
self.load_checkpoint_model_by_path("latest.tar") + + def load_best_model_checkpoint(self): + """ + Load entire model only based on the check point path (model_best.tar) + :return: loaded model + """ + self.load_checkpoint_model_by_path("model_best.tar") + + +class PyTorchClassifier(PyTorchModel): + """ + Wrapper class for pytorch classification models. + """ + + def __init__( + self, + model: "torch.nn.Module", + output_type: ModelOutputType, + loss: "torch.nn.modules.loss._Loss", + input_shape: Tuple[int, ...], + nb_classes: int, + optimizer: "torch.optim.Optimizer", + black_box_access: Optional[bool] = True, + unlimited_queries: Optional[bool] = True, + **kwargs, + ): + """ + Initialization specifically for the PyTorch-based implementation. + + :param model: PyTorch model. The output of the model can be logits, probabilities or anything else. Logits + output should be preferred where possible to ensure attack efficiency. + :param output_type: The type of output the model yields (vector/label only for classifiers, + value for regressors) + :param loss: The loss function for which to compute gradients for training. The target label must be raw + categorical, i.e. not converted to one-hot encoding. + :param input_shape: The shape of one input instance. + :param optimizer: The optimizer used to train the classifier. + :param black_box_access: Boolean describing the type of deployment of the model (when in production). + Set to True if the model is only available via query (API) access, i.e., + only the outputs of the model are exposed, and False if the model internals + are also available. Optional, Default is True. + :param unlimited_queries: If black_box_access is True, this boolean indicates whether a user can perform + unlimited queries to the model API or whether there is a limit to the number of + queries that can be submitted. Optional, Default is True. 
+ """ + super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs) + self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer) + + def fit( + self, + train_data: PytorchData, + validation_data: PytorchData = None, + batch_size: int = 128, + nb_epochs: int = 10, + save_checkpoints: bool = True, + save_entire_model=True, + path=os.getcwd(), + **kwargs, + ) -> None: + """ + Fit the model using the training data. + + :param train_data: Training data. + :type train_data: `PytorchData` + :param validation_data: Training data. + :type train_data: `PytorchData` + :param batch_size: Size of batches. + :param nb_epochs: Number of epochs to use for training. + :param save_checkpoints: Boolean, save checkpoints if True. + :param save_entire_model: Boolean, save entire model if True, else save state dict. + :param path: path for saving checkpoint. + :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently + supported for PyTorch and providing it takes no effect. + """ + if validation_data is None: + self._art_model.fit( + x=train_data.get_samples(), + y=train_data.get_labels().reshape(-1, 1), + batch_size=batch_size, + nb_epochs=nb_epochs, + save_checkpoints=save_checkpoints, + save_entire_model=save_entire_model, + path=path, + **kwargs, + ) + else: + self._art_model.fit( + x=train_data.get_samples(), + y=train_data.get_labels().reshape(-1, 1), + x_validation=validation_data.get_samples(), + y_validation=validation_data.get_labels().reshape(-1, 1), + batch_size=batch_size, + nb_epochs=nb_epochs, + save_checkpoints=save_checkpoints, + save_entire_model=save_entire_model, + path=path, + **kwargs, + ) + + def predict(self, x: PytorchData, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE: + """ + Perform predictions using the model for input `x`. + + :param x: Input samples. + :type x: `np.ndarray` or `pandas.DataFrame` + :return: Predictions from the model (class probabilities, if supported). 
+ """ + return self._art_model.predict(x.get_samples(), **kwargs) + + def score(self, test_data: PytorchData, **kwargs): + """ + Score the model using test data. + :param test_data: Test data. + :type test_data: `PytorchData` + :return: the score as float (between 0 and 1) + """ + y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes) + predicted = self.predict(test_data) + return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0] + + def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None): + """ + Load model only based on the check point path + :param model_name: check point filename + :param path: checkpoint path (default current work dir) + :return: loaded model + """ + self._art_model.load_checkpoint_state_dict_by_path(model_name, path) + + def load_latest_state_dict_checkpoint(self): + """ + Load model state dict only based on the check point path (latest.tar) + :return: loaded model + """ + self._art_model.load_latest_state_dict_checkpoint() + + def load_best_state_dict_checkpoint(self): + """ + Load model state dict only based on the check point path (model_best.tar) + :return: loaded model + """ + self._art_model.load_best_state_dict_checkpoint() + + def load_checkpoint_model_by_path(self, model_name: str, path: str = None): + """ + Load model only based on the check point path + :param model_name: check point filename + :param path: checkpoint path (default current work dir) + :return: loaded model + """ + self._art_model.load_checkpoint_model_by_path(model_name, path) + + def load_latest_model_checkpoint(self): + """ + Load entire model only based on the check point path (latest.tar) + :return: loaded model + """ + self._art_model.load_latest_model_checkpoint() + + def load_best_model_checkpoint(self): + """ + Load entire model only based on the check point path (model_best.tar) + :return: loaded model + """ + self._art_model.load_best_model_checkpoint() 
diff --git a/requirements.txt b/requirements.txt index 849ca96..0cf46eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ -numpy>=1.22 -pandas==1.1.0 -scipy==1.4.1 -scikit-learn==0.22.2 +numpy~=1.22 +pandas~=1.1.0 +scipy>=1.4.1 +scikit-learn>=0.22.2 torch>=1.8.0 adversarial-robustness-toolbox>=1.11.0 # testing -pytest==5.4.2 +pytest>=5.4.2 diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py new file mode 100644 index 0000000..44e9013 --- /dev/null +++ b/tests/test_pytorch.py @@ -0,0 +1,98 @@ +import numpy as np +import torch +from torch import nn, optim + +from apt.utils.datasets import ArrayDataset +from apt.utils.datasets.datasets import PytorchData +from apt.utils.models import ModelOutputType +from apt.utils.models.pytorch_model import PyTorchClassifier +from art.utils import load_nursery + + +class pytorch_model(nn.Module): + + def __init__(self, num_classes, num_features): + super(pytorch_model, self).__init__() + + self.fc1 = nn.Sequential( + nn.Linear(num_features, 1024), + nn.Tanh(), ) + + self.fc2 = nn.Sequential( + nn.Linear(1024, 512), + nn.Tanh(), ) + + self.fc3 = nn.Sequential( + nn.Linear(512, 256), + nn.Tanh(), ) + + self.fc4 = nn.Sequential( + nn.Linear(256, 128), + nn.Tanh(), + ) + + self.classifier = nn.Linear(128, num_classes) + + def forward(self, x): + out = self.fc1(x) + out = self.fc2(out) + out = self.fc3(out) + out = self.fc4(out) + return self.classifier(out) + + +def test_nursery_pytorch_state_dict(): + (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5) + # reduce size of training set to make attack slightly better + train_set_size = 500 + x_train = x_train[:train_set_size] + y_train = y_train[:train_set_size] + x_test = x_test[:train_set_size] + y_test = y_test[:train_set_size] + + inner_model = pytorch_model(4, 24) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(inner_model.parameters(), lr=0.01) + + model = PyTorchClassifier(model=inner_model, 
output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion, + optimizer=optimizer, input_shape=(24,), + nb_classes=4) + model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10) + model.load_latest_state_dict_checkpoint() + score = model.score(PytorchData(x_test.astype(np.float32), y_test)) + print('Base model accuracy: ', score) + assert (0 <= score <= 1) + # python pytorch numpy + model.load_best_state_dict_checkpoint() + score = model.score(PytorchData(x_test.astype(np.float32), y_test)) + print('best model accuracy: ', score) + assert (0 <= score <= 1) + + +def test_nursery_pytorch_save_entire_model(): + + (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5) + # reduce size of training set to make attack slightly better + train_set_size = 500 + x_train = x_train[:train_set_size] + y_train = y_train[:train_set_size] + x_test = x_test[:train_set_size] + y_test = y_test[:train_set_size] + + model = pytorch_model(4, 24) + # model = torch.nn.DataParallel(model) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.01) + + art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion, + optimizer=optimizer, input_shape=(24,), + nb_classes=4) + art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10) + + score = art_model.score(PytorchData(x_test.astype(np.float32), y_test)) + print('Base model accuracy: ', score) + assert (0 <= score <= 1) + art_model.load_best_model_checkpoint() + score = art_model.score(PytorchData(x_test.astype(np.float32), y_test)) + print('best model accuracy: ', score) + assert (0 <= score <= 1)