mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-25 04:46:21 +02:00
update
This commit is contained in:
parent
8459d6961f
commit
023f8764da
2 changed files with 7 additions and 3 deletions
|
|
@ -125,7 +125,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
if is_best:
|
||||
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
|
||||
|
||||
def load_checkpoint(self, model_name: str, path: str = None):
|
||||
def load_checkpoint_by_path(self, model_name: str, path: str = None):
|
||||
"""
|
||||
Load model only based on the check point path
|
||||
:param model_name: check point filename
|
||||
|
|
@ -151,6 +151,12 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
if self._optimizer and 'opt_state_dict' in checkpoint:
|
||||
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
|
||||
|
||||
def load_latest_checkpoint(self):
|
||||
self.load_checkpoint_by_path('latest.tar')
|
||||
|
||||
def load_best_checkpoint(self):
|
||||
self.load_checkpoint_by_path('model_best.tar')
|
||||
|
||||
|
||||
class PyTorchClassifier(PyTorchModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from apt.utils.datasets import ArrayDataset
|
||||
from apt.utils.models import ModelOutputType
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue