diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index ef64006..b1d99ca 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -152,9 +152,17 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): self._optimizer.load_state_dict(checkpoint['opt_state_dict']) def load_latest_checkpoint(self): + """ + Load model only based on the check point path (latest.tar) + :return: loaded model + """ self.load_checkpoint_by_path('latest.tar') def load_best_checkpoint(self): + """ + Load model only based on the check point path (model_best.tar) + :return: loaded model + """ self.load_checkpoint_by_path('model_best.tar')