This commit is contained in:
olasaadi 2022-06-06 14:32:34 +03:00
parent c954f53ad7
commit 21cba95a28
2 changed files with 43 additions and 8 deletions

View file

@ -65,7 +65,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
return total_loss / total, float(correct) / total
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
save_checkpoints: bool = True, **kwargs) -> None:
save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None:
"""
Fit the classifier on the training set `(x, y)`.
:param x: Training data.
@ -74,6 +74,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
:param batch_size: Size of batches.
:param nb_epochs: Number of epochs to use for training.
:param save_checkpoints: Boolean, save checkpoints if True.
:param save_entire_model: Boolean, save entire model if True, else save state dict.
:param path: path for saving checkpoint.
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
supported for PyTorch and providing it takes no effect.
"""
@ -127,18 +129,22 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
train_acc = float(tot_correct) / total
best_acc = max(val_acc, best_acc)
if save_checkpoints:
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc)
if save_entire_model:
self.save_checkpoint_model(is_best=best_acc <= val_acc)
else:
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc)
def save_checkpoint_state_dict(self, is_best: bool,
def save_checkpoint_state_dict(self, is_best: bool, path=os.getcwd(),
filename="latest.tar") -> None:
"""
Saves checkpoint as latest.tar or best.tar
:param is_best: whether the model is the best achieved model
:param path: path for saving checkpoint
:param filename: checkpoint name
:return: None
"""
# add path
checkpoint = os.path.join(os.getcwd(), 'checkpoints')
checkpoint = os.path.join(path, 'checkpoints')
path = checkpoint
os.makedirs(path, exist_ok=True)
filepath = os.path.join(path, filename)
@ -149,14 +155,15 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
if is_best:
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
def save_checkpoint_model(self, is_best: bool, filename="latest.tar") -> None:
def save_checkpoint_model(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
"""
Saves checkpoint as latest.tar or best.tar
:param is_best: whether the model is the best achieved model
:param path: path for saving checkpoint
:param filename: checkpoint name
:return: None
"""
checkpoint = os.path.join(os.getcwd(), 'checkpoints')
checkpoint = os.path.join(path, 'checkpoints')
path = checkpoint
os.makedirs(path, exist_ok=True)
filepath = os.path.join(path, filename)