diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index b1d99ca..c4fbb08 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -6,6 +6,7 @@ from typing import Optional, Tuple, Dict import numpy as np from art.utils import check_and_transform_label_format, logger +from sklearn.model_selection import train_test_split from sklearn.preprocessing import OneHotEncoder @@ -15,6 +16,7 @@ from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier import torch + class PyTorchModel(Model): """ Wrapper class for pytorch models. @@ -35,13 +37,32 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): else: return int(torch.sum(torch.round(outputs, axis=-1) == targets).item()) - def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, + def _eval(self, x: np.ndarray, y: np.ndarray): + + self.model.eval() + + total_loss = 0 + correct = 0 + total = 0 + for m in range(len(x)): + inputs = torch.from_numpy(x[m]).to(self._device) + targets = torch.from_numpy(y[m]).to(self._device) + targets = targets.to(self.device) + outputs = self.model(inputs) + loss = self._loss(outputs, targets) + total_loss += (loss.item() * targets.size(0)) + total += targets.size(0) + correct += self.get_step_correct(outputs, targets) + + return total_loss / total, float(correct) / total + + def fit(self, X: np.ndarray, Y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, save_checkpoints: bool = True, **kwargs) -> None: """ Fit the classifier on the training set `(x, y)`. - :param x: Training data. - :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of + :param X: Training data. + :param Y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of shape (nb_samples,). :param batch_size: Size of batches. 
:param nb_epochs: Number of epochs to use for training. @@ -50,6 +71,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): and providing it takes no effect. """ # Put the model in the training mode + x, X_test, y, y_test = train_test_split(X, Y, test_size=0.33, random_state=42) self._model.train() if self._optimizer is None: # pragma: no cover @@ -95,15 +117,15 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): correct = self.get_step_correct(model_outputs[-1], o_batch) tot_correct += correct total += o_batch.shape[0] + val_loss, val_acc = self._eval(X_test, y_test) train_acc = float(tot_correct) / total if save_checkpoints: - additional_states = {'epoch': epoch + 1, 'acc': train_acc, 'best_acc': val_acc} - self.save_checkpoint(is_best=best_acc <= val_acc, additional_states=additional_states) + self.save_checkpoint_state_dict(is_best=best_acc <= val_acc) best_acc = max(val_acc, best_acc) - def save_checkpoint(self, is_best: bool, additional_states: Dict = None, - filename="latest.tar") -> None: + def save_checkpoint_state_dict(self, is_best: bool, additional_states: Dict = None, + filename="latest.tar") -> None: """ Saves checkpoint as latest.tar or best.tar :param is_best: whether the model is the best achieved model @@ -120,12 +142,25 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): if isinstance(self.model, torch.nn.DataParallel) else self.model.state_dict() state['opt_state_dict'] = self.optimizer.state_dict() torch.save(state, filepath) - logging.info("Saving {} model with validation acc {} and train acc {}". 
- format('best' if is_best else 'checkpoint', state['best_acc'], state['acc'])) if is_best: shutil.copyfile(filepath, os.path.join(path, 'model_best.tar')) - def load_checkpoint_by_path(self, model_name: str, path: str = None): + def save_checkpoint_model(self, is_best: bool, filename="latest.tar") -> None: + """ + Saves checkpoint as latest.tar or best.tar + :param is_best: whether the model is the best achieved model + :param filename: checkpoint name + :return: None + """ + checkpoint = os.path.join(os.getcwd(), 'checkpoints') + path = checkpoint + os.makedirs(path, exist_ok=True) + filepath = os.path.join(path, filename) + torch.save(self.model.module, filepath) + if is_best: + shutil.copyfile(filepath, os.path.join(path, 'model_best.tar')) + + def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None): """ Load model only based on the check point path :param model_name: check point filename :param path: checkpoint path (default current work dir) :return: loaded model @@ -143,27 +178,57 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): else: checkpoint = torch.load(filepath) - if isinstance(self._model, torch.nn.DataParallel): - self._model.module.load_state_dict(checkpoint['state_dict']) - else: - self._model.load_state_dict(checkpoint['state_dict']) + self.model.module.load_state_dict(checkpoint['state_dict']) if self._optimizer and 'opt_state_dict' in checkpoint: self._optimizer.load_state_dict(checkpoint['opt_state_dict']) - def load_latest_checkpoint(self): """ - Load model only based on the check point path (latest.tar) + Load model state dict only based on the check point path (latest.tar) :return: loaded model """ - self.load_checkpoint_by_path('latest.tar') + self.load_checkpoint_state_dict_by_path('latest.tar') - def load_best_checkpoint(self): """ - Load model only based on the check point path (model_best.tar) + Load model state dict only based on the check point path (model_best.tar) :return: loaded model """ - 
self.load_checkpoint_by_path('model_best.tar') + self.load_checkpoint_state_dict_by_path('model_best.tar') + + def load_checkpoint_model_by_path(self, model_name: str, path: str = None): + """ + Load model only based on the check point path + :param model_name: check point filename + :param path: checkpoint path (default current work dir) + :return: loaded model + """ + if path is None: + path = os.path.join(os.getcwd(), 'checkpoints') + + filepath = os.path.join(path, model_name) + if not os.path.exists(filepath): + msg = f"Model file {filepath} not found" + logger.error(msg) + raise FileNotFoundError(msg) + + else: + self.model.module = torch.load(filepath) + + def load_latest_model_checkpoint(self): + """ + Load entire model only based on the check point path (latest.tar) + :return: loaded model + """ + self.load_checkpoint_model_by_path('latest.tar') + + def load_best_model_checkpoint(self): + """ + Load entire model only based on the check point path (model_best.tar) + :return: loaded model + """ + self.load_checkpoint_model_by_path('model_best.tar') class PyTorchClassifier(PyTorchModel): @@ -227,3 +292,49 @@ class PyTorchClassifier(PyTorchModel): y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes) predicted = self.predict(test_data) return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0] + + def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None): + """ + Load model only based on the check point path + :param model_name: check point filename + :param path: checkpoint path (default current work dir) + :return: loaded model + """ + self._art_model.load_checkpoint_state_dict_by_path(model_name, path) + + def load_latest_state_dict_checkpoint(self): + """ + Load model state dict only based on the check point path (latest.tar) + :return: loaded model + """ + self._art_model.load_latest_state_dict_checkpoint() + + def load_best_state_dict_checkpoint(self): + """ + 
Load model state dict only based on the check point path (model_best.tar) + :return: loaded model + """ + self._art_model.load_best_state_dict_checkpoint() + + def load_checkpoint_model_by_path(self, model_name: str, path: str = None): + """ + Load model only based on the check point path + :param model_name: check point filename + :param path: checkpoint path (default current work dir) + :return: loaded model + """ + self._art_model.load_checkpoint_model_by_path(model_name, path) + + def load_latest_model_checkpoint(self): + """ + Load entire model only based on the check point path (latest.tar) + :return: loaded model + """ + self._art_model.load_latest_model_checkpoint() + + def load_best_model_checkpoint(self): + """ + Load entire model only based on the check point path (model_best.tar) + :return: loaded model + """ + self._art_model.load_best_model_checkpoint() diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index dcb0be3..16e7d1f 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -48,16 +48,20 @@ def test_nursery_pytorch(): out = self.fc4(out) return self.classifier(out) - mlp_model = pytorch_model(4, 24) - mlp_model = torch.nn.DataParallel(mlp_model) + model = pytorch_model(4, 24) + model = torch.nn.DataParallel(model) criterion = nn.CrossEntropyLoss() - optimizer = optim.Adam(mlp_model.parameters(), lr=0.01) + optimizer = optim.Adam(model.parameters(), lr=0.01) - mlp_art_model = PyTorchClassifier(model=mlp_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, + art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, optimizer=optimizer, input_shape=(24,), nb_classes=4) - mlp_art_model.fit(ArrayDataset(x_train.astype(np.float32), y_train)) + art_model.fit(ArrayDataset(x_train.astype(np.float32), y_train)) - pred = np.array([np.argmax(arr) for arr in mlp_art_model.predict(ArrayDataset(x_test.astype(np.float32)))]) + pred = np.array([np.argmax(arr) for arr in 
art_model.predict(ArrayDataset(x_test.astype(np.float32)))]) print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test)) + art_model.load_best_state_dict_checkpoint() + + +