This commit is contained in:
olasaadi 2022-06-06 14:02:40 +03:00
parent 302d0c4b8c
commit c954f53ad7
2 changed files with 42 additions and 42 deletions

View file

@ -83,10 +83,10 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
if self._optimizer is None: # pragma: no cover
raise ValueError("An optimizer is needed to train the model, but none for provided.")
y = check_and_transform_label_format(y, self.nb_classes)
_y = check_and_transform_label_format(y, self.nb_classes)
# Apply preprocessing
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, _y, fit=True)
# Check label shape
y_preprocessed = self.reduce_labels(y_preprocessed)
@ -127,14 +127,13 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
train_acc = float(tot_correct) / total
best_acc = max(val_acc, best_acc)
if save_checkpoints:
self.save_checkpoint_model(is_best=best_acc <= val_acc)
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc)
def save_checkpoint_state_dict(self, is_best: bool, additional_states: Dict = None,
def save_checkpoint_state_dict(self, is_best: bool,
filename="latest.tar") -> None:
"""
Saves checkpoint as latest.tar or best.tar
:param is_best: whether the model is the best achieved model
:param additional_states: additional parameters that will be saved with the model
:param filename: checkpoint name
:return: None
"""
@ -143,9 +142,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
path = checkpoint
os.makedirs(path, exist_ok=True)
filepath = os.path.join(path, filename)
state = additional_states if additional_states else dict()
state['state_dict'] = self.model.module.state_dict() \
if isinstance(self.model, torch.nn.DataParallel) else self.model.state_dict()
state = dict()
state['state_dict'] = self.model.state_dict()
state['opt_state_dict'] = self.optimizer.state_dict()
torch.save(state, filepath)
if is_best:
@ -162,7 +160,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
path = checkpoint
os.makedirs(path, exist_ok=True)
filepath = os.path.join(path, filename)
torch.save(self.model.module, filepath)
torch.save(self.model, filepath)
if is_best:
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
@ -184,7 +182,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
else:
checkpoint = torch.load(filepath)
self.model.module.load_state_dict(checkpoint)
self.model.load_state_dict(checkpoint['state_dict'])
if self._optimizer and 'opt_state_dict' in checkpoint:
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
@ -220,7 +218,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
raise FileNotFoundError(msg)
else:
self.model.module = torch.load(path)
self._model = torch.load(filepath)
def load_latest_model_checkpoint(self):
"""