This commit is contained in:
olasaadi 2022-05-30 11:51:22 +03:00
parent 8459d6961f
commit 023f8764da
2 changed files with 7 additions and 3 deletions

View file

@ -125,7 +125,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
if is_best:
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
def load_checkpoint(self, model_name: str, path: str = None):
def load_checkpoint_by_path(self, model_name: str, path: str = None):
"""
Load model only based on the check point path
:param model_name: check point filename
@ -151,6 +151,12 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
if self._optimizer and 'opt_state_dict' in checkpoint:
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
def load_latest_checkpoint(self):
self.load_checkpoint_by_path('latest.tar')
def load_best_checkpoint(self):
self.load_checkpoint_by_path('model_best.tar')
class PyTorchClassifier(PyTorchModel):
"""

View file

@ -1,8 +1,6 @@
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from apt.utils.datasets import ArrayDataset
from apt.utils.models import ModelOutputType