This commit is contained in:
olasaadi 2022-05-30 12:52:32 +03:00
parent 8de77f9afd
commit a3fb68fb56
2 changed files with 141 additions and 26 deletions

View file

@@ -6,6 +6,7 @@ from typing import Optional, Tuple, Dict
import numpy as np
from art.utils import check_and_transform_label_format, logger
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
@@ -15,6 +16,7 @@ from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
import torch
class PyTorchModel(Model):
"""
Wrapper class for pytorch models.
@@ -35,13 +37,32 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
else:
return int(torch.sum(torch.round(outputs, axis=-1) == targets).item())
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
def _eval(self, x: np.ndarray, y: np.ndarray):
    """
    Evaluate the wrapped model on pre-batched data.

    :param x: Sequence of input batches; each ``x[m]`` is one batch as a numpy
              array (assumes first axis of each entry is the batch axis — TODO confirm).
    :param y: Sequence of target batches aligned with ``x``.
    :return: Tuple ``(mean loss per sample, accuracy)``.
    """
    # Inference mode: makes dropout/batch-norm layers deterministic.
    self.model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    # No gradients are needed for evaluation; avoids building the autograd graph.
    with torch.no_grad():
        for m in range(len(x)):
            inputs = torch.from_numpy(x[m]).to(self._device)
            # Original moved targets to the device twice (once via self._device,
            # once via self.device) — a single transfer is sufficient.
            targets = torch.from_numpy(y[m]).to(self._device)
            outputs = self.model(inputs)
            loss = self._loss(outputs, targets)
            # Weight each batch loss by its size so the final mean is per-sample.
            total_loss += loss.item() * targets.size(0)
            total += targets.size(0)
            correct += self.get_step_correct(outputs, targets)
    return total_loss / total, float(correct) / total
def fit(self, X: np.ndarray, Y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
save_checkpoints: bool = True, **kwargs) -> None:
"""
Fit the classifier on the training set `(x, y)`.
:param x: Training data.
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
:param X: Training data.
:param Y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
shape (nb_samples,).
:param batch_size: Size of batches.
:param nb_epochs: Number of epochs to use for training.
@@ -50,6 +71,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
and providing it takes no effect.
"""
# Put the model in the training mode
x, X_test, y, y_test = train_test_split(X, Y, test_size=0.33, random_state=42)
self._model.train()
if self._optimizer is None: # pragma: no cover
@@ -95,15 +117,15 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
correct = self.get_step_correct(model_outputs[-1], o_batch)
tot_correct += correct
total += o_batch.shape[0]
val_loss, val_acc = self._eval(X_test, y_test)
train_acc = float(tot_correct) / total
if save_checkpoints:
additional_states = {'epoch': epoch + 1, 'acc': train_acc, 'best_acc': val_acc}
self.save_checkpoint(is_best=best_acc <= val_acc, additional_states=additional_states)
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc)
best_acc = max(val_acc, best_acc)
def save_checkpoint(self, is_best: bool, additional_states: Dict = None,
filename="latest.tar") -> None:
def save_checkpoint_state_dict(self, is_best: bool, additional_states: Dict = None,
filename="latest.tar") -> None:
"""
Saves checkpoint as latest.tar or best.tar
:param is_best: whether the model is the best achieved model
@@ -120,12 +142,25 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
if isinstance(self.model, torch.nn.DataParallel) else self.model.state_dict()
state['opt_state_dict'] = self.optimizer.state_dict()
torch.save(state, filepath)
logging.info("Saving {} model with validation acc {} and train acc {}".
format('best' if is_best else 'checkpoint', state['best_acc'], state['acc']))
if is_best:
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
def load_checkpoint_by_path(self, model_name: str, path: str = None):
def save_checkpoint_model(self, is_best: bool, filename="latest.tar") -> None:
    """
    Save the entire model object (not just the state dict) as a checkpoint.

    :param is_best: whether the model is the best achieved model; if True the
                    checkpoint is additionally copied to ``model_best.tar``
    :param filename: checkpoint file name (saved under ``./checkpoints``)
    :return: None
    """
    path = os.path.join(os.getcwd(), 'checkpoints')
    os.makedirs(path, exist_ok=True)
    filepath = os.path.join(path, filename)
    # Unwrap DataParallel when present so the pickled object is the bare module.
    # The original accessed ``self.model.module`` unconditionally, which raises
    # AttributeError for plain (non-DataParallel) modules; the state-dict saver
    # in this class already guards with the same isinstance check.
    model_to_save = self.model.module \
        if isinstance(self.model, torch.nn.DataParallel) else self.model
    torch.save(model_to_save, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
"""
Load model only based on the check point path
:param model_name: check point filename
@@ -143,27 +178,57 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
else:
checkpoint = torch.load(filepath)
if isinstance(self._model, torch.nn.DataParallel):
self._model.module.load_state_dict(checkpoint['state_dict'])
else:
self._model.load_state_dict(checkpoint['state_dict'])
self.model.module.load_state_dict(checkpoint)
if self._optimizer and 'opt_state_dict' in checkpoint:
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
def load_latest_checkpoint(self):
def load_latest_state_dict_checkpoint(self):
"""
Load model only based on the check point path (latest.tar)
Load model state dict only based on the check point path (latest.tar)
:return: loaded model
"""
self.load_checkpoint_by_path('latest.tar')
self.load_checkpoint_state_dict_by_path('latest.tar')
def load_best_checkpoint(self):
def load_best_state_dict_checkpoint(self):
"""
Load model only based on the check point path (model_best.tar)
Load model state dict only based on the check point path (model_best.tar)
:return: loaded model
"""
self.load_checkpoint_by_path('model_best.tar')
self.load_checkpoint_state_dict_by_path('model_best.tar')
def load_checkpoint_model_by_path(self, model_name: str, path: str = None):
    """
    Load an entire pickled model from a checkpoint file.

    :param model_name: checkpoint filename
    :param path: checkpoint directory (default: ``./checkpoints`` under the
                 current working directory)
    :return: None; the loaded module replaces the wrapped model in place
    :raises FileNotFoundError: if the checkpoint file does not exist
    """
    if path is None:
        path = os.path.join(os.getcwd(), 'checkpoints')
    filepath = os.path.join(path, model_name)
    if not os.path.exists(filepath):
        msg = f"Model file {filepath} not found"
        logger.error(msg)
        raise FileNotFoundError(msg)
    # Bug fix: the original called torch.load(path), i.e. tried to load the
    # *directory* rather than the checkpoint file that was just validated.
    # NOTE(review): assigning to ``self.model.module`` assumes the wrapped model
    # is a torch.nn.DataParallel — confirm against how this wrapper is built.
    self.model.module = torch.load(filepath)
def load_latest_model_checkpoint(self):
    """
    Restore the most recently saved full-model checkpoint.

    Convenience wrapper around :meth:`load_checkpoint_model_by_path` that
    targets the default ``latest.tar`` file in the checkpoints directory.
    :return: None; the wrapped model is replaced in place
    """
    self.load_checkpoint_model_by_path('latest.tar')
def load_best_model_checkpoint(self):
    """
    Restore the best full-model checkpoint achieved during training.

    Convenience wrapper around :meth:`load_checkpoint_model_by_path` that
    targets the ``model_best.tar`` file in the checkpoints directory.
    :return: None; the wrapped model is replaced in place
    """
    self.load_checkpoint_model_by_path('model_best.tar')
class PyTorchClassifier(PyTorchModel):
@@ -227,3 +292,49 @@ class PyTorchClassifier(PyTorchModel):
y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes)
predicted = self.predict(test_data)
return np.count_nonzero(np.argmax(y, axis=1) == np.argmax(predicted, axis=1)) / predicted.shape[0]
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
    """
    Load a model state dict from a named checkpoint file.

    Delegates to the underlying ART classifier wrapper.
    :param model_name: checkpoint filename
    :param path: checkpoint directory (defaults to ``./checkpoints`` under cwd)
    :return: None
    """
    self._art_model.load_checkpoint_state_dict_by_path(model_name, path)
def load_latest_state_dict_checkpoint(self):
    """
    Load the model state dict from the latest checkpoint (``latest.tar``).

    Delegates to the underlying ART classifier wrapper.
    :return: None
    """
    self._art_model.load_latest_state_dict_checkpoint()
def load_best_state_dict_checkpoint(self):
    """
    Load the model state dict from the best checkpoint (``model_best.tar``).

    Delegates to the underlying ART classifier wrapper.
    :return: None
    """
    self._art_model.load_best_state_dict_checkpoint()
def load_checkpoint_model_by_path(self, model_name: str, path: str = None):
    """
    Load an entire pickled model from a named checkpoint file.

    Delegates to the underlying ART classifier wrapper.
    :param model_name: checkpoint filename
    :param path: checkpoint directory (defaults to ``./checkpoints`` under cwd)
    :return: None
    """
    self._art_model.load_checkpoint_model_by_path(model_name, path)
def load_latest_model_checkpoint(self):
    """
    Load the entire model from the latest checkpoint (``latest.tar``).

    Delegates to the underlying ART classifier wrapper.
    :return: None
    """
    self._art_model.load_latest_model_checkpoint()
def load_best_model_checkpoint(self):
    """
    Load the entire model from the best checkpoint (``model_best.tar``).

    Delegates to the underlying ART classifier wrapper.
    :return: None
    """
    self._art_model.load_best_model_checkpoint()