diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index ff33d48..ef64006 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -125,7 +125,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): if is_best: shutil.copyfile(filepath, os.path.join(path, 'model_best.tar')) - def load_checkpoint(self, model_name: str, path: str = None): + def load_checkpoint_by_path(self, model_name: str, path: str = None): """ Load model only based on the check point path :param model_name: check point filename @@ -151,6 +151,12 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): if self._optimizer and 'opt_state_dict' in checkpoint: self._optimizer.load_state_dict(checkpoint['opt_state_dict']) + def load_latest_checkpoint(self): + self.load_checkpoint_by_path('latest.tar') + + def load_best_checkpoint(self): + self.load_checkpoint_by_path('model_best.tar') + class PyTorchClassifier(PyTorchModel): """ diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index bdff6aa..dcb0be3 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -1,8 +1,6 @@ import numpy as np import torch from torch import nn, optim -from torch.utils.data import DataLoader -from torch.utils.data.dataset import Dataset from apt.utils.datasets import ArrayDataset from apt.utils.models import ModelOutputType