mirror of https://github.com/IBM/ai-privacy-toolkit.git, synced 2026-05-05 09:42:37 +02:00

save checkpoints

This commit is contained in:
parent 521c8ce041
commit 7539ca0ead

1 changed file with 80 additions and 4 deletions
@@ -1,8 +1,11 @@
+import logging
+import os
 import random
-from typing import Optional, Tuple
+import shutil
+from typing import Optional, Tuple, Dict
 
 import numpy as np
-from art.utils import check_and_transform_label_format
+from art.utils import check_and_transform_label_format, logger
 
 from sklearn.preprocessing import OneHotEncoder
 
@@ -24,7 +27,18 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
     Wrapper class for pytorch classifier model.
     """
 
-    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None:
+    def get_step_correct(self, outputs, targets) -> int:
+        """get number of correctly classified labels"""
+        if len(outputs) != len(targets):
+            raise ValueError("outputs and targets should be the same length.")
+        counter = 0
+        for i, o in enumerate(outputs):
+            if o == targets[i]:
+                counter += 1
+        return counter
+
+    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
+            save_checkpoints: bool = True, **kwargs) -> None:
         """
         Fit the classifier on the training set `(x, y)`.
 
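Note that get_step_correct compares outputs and targets element-wise, so it only counts hits when the caller passes discrete predicted labels (for example an argmax over the logits) rather than raw scores. A minimal standalone sketch of the same count, vectorized with NumPy; the function name and sample arrays here are hypothetical, not part of the commit:

import numpy as np

def step_correct(outputs: np.ndarray, targets: np.ndarray) -> int:
    # Same contract as get_step_correct: equal lengths, label-vs-label comparison.
    if len(outputs) != len(targets):
        raise ValueError("outputs and targets should be the same length.")
    return int((outputs == targets).sum())

preds = np.array([1, 0, 2, 2])    # e.g. logits.argmax(axis=1)
labels = np.array([1, 0, 1, 2])
assert step_correct(preds, labels) == 3  # three positions match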
@@ -33,6 +47,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
+        :param save_checkpoints: Boolean, save checkpoints if True.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
                        and providing it takes no effect.
         """
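A minimal usage sketch for the extended fit, assuming the wrapper keeps the ART PyTorchClassifier constructor it inherits (model, loss, optimizer, input_shape, nb_classes); the toy model and data below are placeholders:

import os
import numpy as np
import torch

# Hypothetical toy setup; constructor arguments follow the inherited
# ArtPyTorchClassifier signature.
model = torch.nn.Sequential(torch.nn.Linear(4, 3))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

clf = PyTorchClassifierWrapper(
    model=model, loss=criterion, optimizer=optimizer,
    input_shape=(4,), nb_classes=3,
)

# save_checkpoint writes into ./checkpoints without creating it,
# so make sure the directory exists before training.
os.makedirs(os.path.join(os.getcwd(), 'checkpoints'), exist_ok=True)

x = np.random.rand(64, 4).astype(np.float32)
y = np.random.randint(0, 3, size=64)

# Writes checkpoints/latest.tar each epoch; save_checkpoints=False skips it.
clf.fit(x, y, batch_size=16, nb_epochs=2, save_checkpoints=True)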
@@ -54,7 +69,11 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
         ind = np.arange(len(x_preprocessed))
 
         # Start training
-        for _ in range(nb_epochs):
+        for epoch in range(nb_epochs):
+            tot_correct = 0
+            total = 0
+            val_acc = 0
+            best_acc = 0
             # Shuffle the examples
             random.shuffle(ind)
 
@@ -75,6 +94,63 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
                 loss.backward()
 
                 self._optimizer.step()
+                correct = self.get_step_correct(model_outputs, o_batch)
+                tot_correct += correct
+                total += o_batch.shape[0]
+                train_acc = float(tot_correct) / total
+
+            if save_checkpoints:
+                additional_states = {'epoch': epoch + 1, 'acc': train_acc, 'best_acc': val_acc}
+                self.save_checkpoint(is_best=best_acc <= val_acc, additional_states=additional_states)
+                best_acc = max(val_acc, best_acc)
+
+    def save_checkpoint(self, is_best: bool, additional_states: Dict = None,
+                        filename="latest.tar") -> None:
+        """
+        Saves checkpoint as latest.tar or model_best.tar
+        :param is_best: whether the model is the best achieved model
+        :param additional_states: additional parameters that will be saved with the model
+        :param filename: checkpoint name
+        :return: None
+        """
+        checkpoint = os.path.join(os.getcwd(), 'checkpoints')
+        path = checkpoint
+        filepath = os.path.join(path, filename)
+        state = additional_states if additional_states else dict()
+        state['state_dict'] = self.model.module.state_dict() \
+            if isinstance(self.model, torch.nn.DataParallel) else self.model.state_dict()
+        state['opt_state_dict'] = self.optimizer.state_dict()
+        torch.save(state, filepath)
+        logging.info("Saving {} model with validation acc {} and train acc {}".
+                     format('best' if is_best else 'checkpoint', state['best_acc'], state['acc']))
+        if is_best:
+            shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
+
+    def load_checkpoint(self, model_name: str, path: str = None):
+        """
+        Load a model based only on the checkpoint path
+        :param model_name: checkpoint filename
+        :param path: checkpoint path (default: current working dir)
+        :return: loaded model
+        """
+        if path is None:
+            path = os.path.join(os.getcwd(), 'checkpoints')
+
+        filepath = os.path.join(path, model_name)
+        if not os.path.exists(filepath):
+            msg = f"Model file {filepath} not found"
+            logger.error(msg)
+            raise FileNotFoundError(msg)
+
+        else:
+            checkpoint = torch.load(filepath)
+            if isinstance(self._model, torch.nn.DataParallel):
+                self._model.module.load_state_dict(checkpoint['state_dict'])
+            else:
+                self._model.load_state_dict(checkpoint['state_dict'])
+
+            if self._optimizer and 'opt_state_dict' in checkpoint:
+                self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
 
 
 class PyTorchClassifier(PyTorchModel):
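Continuing the sketch above, restoring weights later only needs the checkpoint filename; by default load_checkpoint looks under ./checkpoints in the current working directory, and it also restores the optimizer state when the file contains one:

# Reload the last checkpoint written by fit (default path: ./checkpoints).
clf.load_checkpoint('latest.tar')

# Or the best-so-far copy, from an explicit directory (hypothetical path).
clf.load_checkpoint('model_best.tar', path='/tmp/run1/checkpoints')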