This commit is contained in:
olasaadi 2022-05-30 11:52:47 +03:00
parent 023f8764da
commit 8de77f9afd

View file

@ -152,9 +152,17 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
def load_latest_checkpoint(self):
"""
Load model only based on the check point path (latest.tar)
:return: loaded model
"""
self.load_checkpoint_by_path('latest.tar')
def load_best_checkpoint(self):
"""
Load model only based on the check point path (model_best.tar)
:return: loaded model
"""
self.load_checkpoint_by_path('model_best.tar')