diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py
index 9659b7c..ec03f70 100644
--- a/apt/utils/models/pytorch_model.py
+++ b/apt/utils/models/pytorch_model.py
@@ -1,8 +1,10 @@
+import os
 import random
-from typing import Optional, Tuple
+import shutil
+from typing import Optional, Tuple, Dict
 
 import numpy as np
-from art.utils import check_and_transform_label_format
+from art.utils import check_and_transform_label_format, logger
 from sklearn.preprocessing import OneHotEncoder
 
 
@@ -24,7 +26,18 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
     Wrapper class for pytorch classifier model.
     """
 
-    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None:
+    def get_step_correct(self, outputs, targets) -> int:
+        """Get the number of correctly classified labels."""
+        if len(outputs) != len(targets):
+            raise ValueError("outputs and targets should be the same length.")
+        counter = 0
+        for i, o in enumerate(outputs):
+            if o == targets[i]:
+                counter += 1
+        return counter
+
+    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
+            save_checkpoints: bool = True, **kwargs) -> None:
         """
         Fit the classifier on the training set `(x, y)`.
 
@@ -33,6 +46,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
+        :param save_checkpoints: If True, save a checkpoint after every epoch.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
                        PyTorch and providing it takes no effect.
         """
@@ -54,7 +68,10 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
         ind = np.arange(len(x_preprocessed))
 
         # Start training
-        for _ in range(nb_epochs):
+        best_acc = 0.0  # best epoch accuracy seen so far, used for checkpointing
+        for epoch in range(nb_epochs):
+            tot_correct = 0
+            total = 0
             # Shuffle the examples
             random.shuffle(ind)
 
@@ -75,6 +92,66 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
                 loss.backward()
 
                 self._optimizer.step()
 
+                # Accumulate accuracy statistics; predictions are the argmax over the final logits
+                preds = torch.argmax(model_outputs[-1], dim=1)
+                targets = o_batch.argmax(dim=1) if o_batch.ndim > 1 else o_batch
+                tot_correct += self.get_step_correct(preds, targets)
+                total += o_batch.shape[0]
+
+            train_acc = float(tot_correct) / total
+
+            if save_checkpoints:
+                # No validation split is available here, so the best model is tracked by train accuracy
+                additional_states = {'epoch': epoch + 1, 'acc': train_acc, 'best_acc': max(train_acc, best_acc)}
+                self.save_checkpoint(is_best=best_acc <= train_acc, additional_states=additional_states)
+                best_acc = max(train_acc, best_acc)
+
+    def save_checkpoint(self, is_best: bool, additional_states: Optional[Dict] = None,
+                        filename: str = "latest.tar") -> None:
+        """
+        Save a checkpoint to `checkpoints/<filename>`; the best model is also copied to `model_best.tar`.
+        :param is_best: Whether the model is the best achieved so far.
+        :param additional_states: Additional entries (e.g. epoch, accuracy) saved alongside the model.
+        :param filename: Checkpoint filename.
+        :return: None
+        """
+        path = os.path.join(os.getcwd(), 'checkpoints')
+        os.makedirs(path, exist_ok=True)
+        filepath = os.path.join(path, filename)
+        state = additional_states if additional_states else dict()
+        state['state_dict'] = self.model.module.state_dict() \
+            if isinstance(self.model, torch.nn.DataParallel) else self.model.state_dict()
+        state['opt_state_dict'] = self._optimizer.state_dict()
+        torch.save(state, filepath)
+        logger.info("Saving %s model with best acc %s and train acc %s",
+                    'best' if is_best else 'checkpoint', state.get('best_acc'), state.get('acc'))
+        if is_best:
+            shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
+
+    def load_checkpoint(self, model_name: str, path: Optional[str] = None) -> None:
+        """
+        Load the model and optimizer state from a checkpoint file.
+        :param model_name: Checkpoint filename.
+        :param path: Checkpoint directory (defaults to `checkpoints/` under the current working directory).
+        :return: None
+        """
+        if path is None:
+            path = os.path.join(os.getcwd(), 'checkpoints')
+
+        filepath = os.path.join(path, model_name)
+        if not os.path.exists(filepath):
+            msg = f"Model file {filepath} not found"
+            logger.error(msg)
+            raise FileNotFoundError(msg)
+
+        checkpoint = torch.load(filepath)
+        if isinstance(self.model, torch.nn.DataParallel):
+            self.model.module.load_state_dict(checkpoint['state_dict'])
+        else:
+            self.model.load_state_dict(checkpoint['state_dict'])
+
+        if self._optimizer is not None and 'opt_state_dict' in checkpoint:
+            self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
 
 
 class PyTorchClassifier(PyTorchModel):
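
A minimal usage sketch of the checkpoint flow introduced above. The `nn.Sequential` model, optimizer settings, and random data are hypothetical stand-ins, and the constructor arguments assume `PyTorchClassifierWrapper` keeps the signature of ART's `PyTorchClassifier`, which it subclasses:

    import numpy as np
    import torch
    from torch import nn, optim
    from apt.utils.models.pytorch_model import PyTorchClassifierWrapper

    # Hypothetical two-class model on 4-dimensional inputs
    model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    wrapper = PyTorchClassifierWrapper(
        model=model,
        loss=nn.CrossEntropyLoss(),
        optimizer=optimizer,
        input_shape=(4,),
        nb_classes=2,
    )

    # Random toy data, just to exercise the API
    x = np.random.rand(64, 4).astype(np.float32)
    y = np.random.randint(0, 2, size=64)

    # Writes checkpoints/latest.tar each epoch and copies the best one to model_best.tar
    wrapper.fit(x, y, batch_size=16, nb_epochs=3, save_checkpoints=True)

    # Later: restore the best weights and optimizer state
    wrapper.load_checkpoint("model_best.tar")

Since `fit` has no validation split, "best" here means best training accuracy; callers that need validation-based selection would have to compute it outside `fit` and call `save_checkpoint` themselves.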