save checkpoints

olasaadi 2022-05-19 04:29:55 +03:00
parent 521c8ce041
commit 7539ca0ead


@@ -1,8 +1,11 @@
 import logging
 import os
 import random
-from typing import Optional, Tuple
+import shutil
+from typing import Optional, Tuple, Dict
 import numpy as np
-from art.utils import check_and_transform_label_format
+from art.utils import check_and_transform_label_format, logger
 from sklearn.preprocessing import OneHotEncoder
@@ -24,7 +27,18 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
     Wrapper class for a PyTorch classifier model.
     """
-    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None:
+    def get_step_correct(self, outputs, targets) -> int:
+        """Get the number of correctly classified samples, assuming `outputs` holds per-sample predicted labels."""
+        if len(outputs) != len(targets):
+            raise ValueError("outputs and targets should be the same length.")
+        counter = 0
+        for i, o in enumerate(outputs):
+            if o == targets[i]:
+                counter += 1
+        return counter
+
+    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
+            save_checkpoints: bool = True, **kwargs) -> None:
         """
         Fit the classifier on the training set `(x, y)`.
@@ -33,6 +47,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
+        :param save_checkpoints: If True, save a checkpoint after every training epoch.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
                        PyTorch, and providing it has no effect.
         """
@@ -54,7 +69,11 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
         ind = np.arange(len(x_preprocessed))
         # Start training
-        for _ in range(nb_epochs):
+        best_acc = 0  # best accuracy seen so far, tracked across epochs
+        for epoch in range(nb_epochs):
+            tot_correct = 0
+            total = 0
+            val_acc = 0  # placeholder: no validation pass runs in this loop
             # Shuffle the examples
             random.shuffle(ind)
@@ -75,6 +94,63 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
                 loss.backward()
                 self._optimizer.step()
+                correct = self.get_step_correct(model_outputs, o_batch)
+                tot_correct += correct
+                total += o_batch.shape[0]
+
+            train_acc = float(tot_correct) / total
+            if save_checkpoints:
+                additional_states = {'epoch': epoch + 1, 'acc': train_acc, 'best_acc': val_acc}
+                self.save_checkpoint(is_best=best_acc <= val_acc, additional_states=additional_states)
+                best_acc = max(val_acc, best_acc)
+
+    def save_checkpoint(self, is_best: bool, additional_states: Optional[Dict] = None,
+                        filename: str = "latest.tar") -> None:
+        """
+        Save a checkpoint as `latest.tar`, and copy it to `model_best.tar` when it is the best model so far.
+
+        :param is_best: Whether the model is the best achieved so far.
+        :param additional_states: Additional values that will be saved with the model state.
+        :param filename: Checkpoint file name.
+        :return: None
+        """
+        path = os.path.join(os.getcwd(), 'checkpoints')
+        os.makedirs(path, exist_ok=True)  # create the checkpoint directory if it does not exist
+        filepath = os.path.join(path, filename)
+        state = additional_states if additional_states else dict()
+        # Unwrap DataParallel so the checkpoint can be loaded on a single device.
+        state['state_dict'] = self.model.module.state_dict() \
+            if isinstance(self.model, torch.nn.DataParallel) else self.model.state_dict()
+        state['opt_state_dict'] = self.optimizer.state_dict()
+        torch.save(state, filepath)
+        logging.info("Saving {} model with validation acc {} and train acc {}".format(
+            'best' if is_best else 'checkpoint', state.get('best_acc'), state.get('acc')))
+        if is_best:
+            shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
+
+    def load_checkpoint(self, model_name: str, path: Optional[str] = None) -> None:
+        """
+        Load the model weights (and optimizer state, if present) from a checkpoint file.
+
+        :param model_name: Checkpoint file name.
+        :param path: Checkpoint directory (defaults to 'checkpoints' under the current working directory).
+        :return: None
+        """
+        if path is None:
+            path = os.path.join(os.getcwd(), 'checkpoints')
+        filepath = os.path.join(path, model_name)
+        if not os.path.exists(filepath):
+            msg = f"Model file {filepath} not found"
+            logger.error(msg)
+            raise FileNotFoundError(msg)
+        checkpoint = torch.load(filepath)
+        if isinstance(self._model, torch.nn.DataParallel):
+            self._model.module.load_state_dict(checkpoint['state_dict'])
+        else:
+            self._model.load_state_dict(checkpoint['state_dict'])
+        if self._optimizer and 'opt_state_dict' in checkpoint:
+            self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
+

 class PyTorchClassifier(PyTorchModel):
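
For reference, a minimal usage sketch of the checkpointing flow this commit adds, assuming the wrapper keeps the standard ART PyTorchClassifier constructor; the model, optimizer, shapes, and data below are illustrative placeholders, not part of the commit:

import numpy as np
import torch
import torch.nn as nn

# Placeholder model and training setup (illustrative only).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

clf = PyTorchClassifierWrapper(
    model=model,
    loss=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    input_shape=(1, 28, 28),
    nb_classes=10,
)

x_train = np.random.rand(256, 1, 28, 28).astype(np.float32)
y_train = np.random.randint(0, 10, size=256)

# Each epoch writes checkpoints/latest.tar; the best epoch is copied to model_best.tar.
clf.fit(x_train, y_train, batch_size=64, nb_epochs=3, save_checkpoints=True)

# Later, restore the weights and optimizer state from the best checkpoint.
clf.load_checkpoint('model_best.tar')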