mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-24 20:36:21 +02:00
update pytorch wrapper to use torch loaders
fix tests and dataset style
This commit is contained in:
parent
fdc6005fce
commit
c77e34e373
4 changed files with 178 additions and 113 deletions
|
|
@ -39,7 +39,7 @@ def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
|
|||
if type(arr) == Tensor:
|
||||
return arr.detach().cpu().numpy()
|
||||
|
||||
raise ValueError('Non supported type: ', type(arr).__name__)
|
||||
raise ValueError("Non supported type: ", type(arr).__name__)
|
||||
|
||||
|
||||
def array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
|
||||
|
|
@ -56,7 +56,7 @@ def array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
|
|||
if type(arr) == Tensor:
|
||||
return arr
|
||||
|
||||
raise ValueError('Non supported type: ', type(arr).__name__)
|
||||
raise ValueError("Non supported type: ", type(arr).__name__)
|
||||
|
||||
|
||||
class Dataset(metaclass=ABCMeta):
|
||||
|
|
@ -109,7 +109,7 @@ class StoredDataset(Dataset):
|
|||
os.makedirs(dest_path, exist_ok=True)
|
||||
logger.info("Downloading the dataset...")
|
||||
urllib.request.urlretrieve(url, file_path)
|
||||
logger.info('Dataset Downloaded')
|
||||
logger.info("Dataset Downloaded")
|
||||
|
||||
if unzip:
|
||||
StoredDataset.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
|
||||
|
|
@ -156,7 +156,7 @@ class StoredDataset(Dataset):
|
|||
logger.info("Shuffling data")
|
||||
np.random.shuffle(data)
|
||||
|
||||
debug_data = data[:int(len(data) * ratio)]
|
||||
debug_data = data[: int(len(data) * ratio)]
|
||||
logger.info(f"Saving {ratio} of the data to {dest_datafile}")
|
||||
np.savetxt(dest_datafile, debug_data, delimiter=delimiter, fmt=fmt)
|
||||
|
||||
|
|
@ -164,8 +164,13 @@ class StoredDataset(Dataset):
|
|||
class ArrayDataset(Dataset):
|
||||
"""Dataset that is based on x and y arrays (e.g., numpy/pandas/list...)"""
|
||||
|
||||
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None,
|
||||
features_names: Optional = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
x: INPUT_DATA_ARRAY_TYPE,
|
||||
y: Optional[INPUT_DATA_ARRAY_TYPE] = None,
|
||||
features_names: Optional = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
ArrayDataset constructor.
|
||||
:param x: collection of data samples
|
||||
|
|
@ -183,7 +188,7 @@ class ArrayDataset(Dataset):
|
|||
self.features_names = x.columns.to_list()
|
||||
|
||||
if y is not None and len(self._x) != len(self._y):
|
||||
raise ValueError('Non equivalent lengths of x and y')
|
||||
raise ValueError("Non equivalent lengths of x and y")
|
||||
|
||||
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""Return data samples as numpy array"""
|
||||
|
|
@ -195,7 +200,6 @@ class ArrayDataset(Dataset):
|
|||
|
||||
|
||||
class PytorchData(Dataset):
|
||||
|
||||
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
|
||||
"""
|
||||
PytorchData constructor.
|
||||
|
|
@ -210,15 +214,13 @@ class PytorchData(Dataset):
|
|||
self.features_names = x.columns
|
||||
|
||||
if y is not None and len(self._x) != len(self._y):
|
||||
raise ValueError('Non equivalent lengths of x and y')
|
||||
|
||||
raise ValueError("Non equivalent lengths of x and y")
|
||||
|
||||
if self._y is not None:
|
||||
self.__getitem__ = self.get_item
|
||||
else:
|
||||
self.__getitem__ = self.get_sample_item
|
||||
|
||||
|
||||
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""Return data samples as numpy array"""
|
||||
return array2numpy(self, self._x)
|
||||
|
|
@ -240,6 +242,7 @@ class PytorchData(Dataset):
|
|||
|
||||
class DatasetFactory:
|
||||
"""Factory class for dataset creation"""
|
||||
|
||||
registry = {}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -252,7 +255,7 @@ class DatasetFactory:
|
|||
|
||||
def inner_wrapper(wrapped_class: Dataset) -> Any:
|
||||
if name in cls.registry:
|
||||
logger.warning('Dataset %s already exists. Will replace it', name)
|
||||
logger.warning("Dataset %s already exists. Will replace it", name)
|
||||
cls.registry[name] = wrapped_class
|
||||
return wrapped_class
|
||||
|
||||
|
|
@ -270,7 +273,7 @@ class DatasetFactory:
|
|||
:return: An instance of the dataset that is created.
|
||||
"""
|
||||
if name not in cls.registry:
|
||||
msg = f'Dataset {name} does not exist in the registry'
|
||||
msg = f"Dataset {name} does not exist in the registry"
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,21 @@
|
|||
import logging
|
||||
""" Pytorch Model Wrapper"""
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from typing import Optional, Tuple
|
||||
import logging
|
||||
|
||||
from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
from art.utils import check_and_transform_label_format, logger
|
||||
from apt.utils.datasets.datasets import PytorchData
|
||||
from apt.utils.models import Model, ModelOutputType
|
||||
from apt.utils.datasets import OUTPUT_DATA_ARRAY_TYPE
|
||||
|
||||
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PyTorchModel(Model):
|
||||
|
|
@ -22,9 +26,9 @@ class PyTorchModel(Model):
|
|||
|
||||
class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
||||
"""
|
||||
Wrapper class for pytorch classifier model.
|
||||
Extension for Pytorch ART model
|
||||
"""
|
||||
Wrapper class for pytorch classifier model.
|
||||
Extension for Pytorch ART model
|
||||
"""
|
||||
|
||||
def get_step_correct(self, outputs, targets) -> int:
|
||||
"""get number of correctly classified labels"""
|
||||
|
|
@ -35,43 +39,47 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
else:
|
||||
return int(torch.sum(torch.round(outputs, axis=-1) == targets).item())
|
||||
|
||||
def _eval(self, x: np.ndarray, y: np.ndarray, nb_epochs, batch_size):
|
||||
|
||||
def _eval(self, loader: DataLoader):
|
||||
"""inner function for model evaluation"""
|
||||
self.model.eval()
|
||||
|
||||
total_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
y = check_and_transform_label_format(y, self.nb_classes)
|
||||
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
|
||||
y_preprocessed = self.reduce_labels(y_preprocessed)
|
||||
num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
|
||||
ind = np.arange(len(x_preprocessed))
|
||||
for epoch in range(nb_epochs):
|
||||
random.shuffle(ind)
|
||||
for m in range(num_batch):
|
||||
inputs = torch.from_numpy(x_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device)
|
||||
targets = torch.from_numpy(y_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device)
|
||||
targets = targets.to(self.device)
|
||||
outputs = self.model(inputs)
|
||||
loss = self._loss(outputs, targets)
|
||||
total_loss += (loss.item() * targets.size(0))
|
||||
total += targets.size(0)
|
||||
correct += self.get_step_correct(outputs, targets)
|
||||
|
||||
for inputs, targets in loader:
|
||||
inputs = inputs.to(self._device)
|
||||
targets = targets.to(self._device)
|
||||
|
||||
outputs = self.model(inputs)
|
||||
loss = self._loss(outputs, targets)
|
||||
total_loss += loss.item() * targets.size(0)
|
||||
total += targets.size(0)
|
||||
correct += self.get_step_correct(outputs, targets)
|
||||
|
||||
return total_loss / total, float(correct) / total
|
||||
|
||||
def fit(self, x: np.ndarray, y: np.ndarray, x_validation: np.ndarray = None, y_validation: np.ndarray = None,
|
||||
batch_size: int = 128, nb_epochs: int = 10, save_checkpoints: bool = True, save_entire_model=True,
|
||||
path=os.getcwd(), **kwargs) -> None:
|
||||
def fit(
|
||||
self,
|
||||
x: np.ndarray,
|
||||
y: np.ndarray,
|
||||
x_validation: np.ndarray = None,
|
||||
y_validation: np.ndarray = None,
|
||||
batch_size: int = 128,
|
||||
nb_epochs: int = 10,
|
||||
save_checkpoints: bool = True,
|
||||
save_entire_model=True,
|
||||
path=os.getcwd(),
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Fit the classifier on the training set `(x, y)`.
|
||||
:param x: Training data.
|
||||
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels
|
||||
of shape (nb_samples,).
|
||||
:param x_validation: Validation data (optional).
|
||||
:param y_validation: Target validation values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels
|
||||
of shape (nb_samples,) (optional).
|
||||
:param y_validation: Target validation values (class labels) one-hot-encoded of shape
|
||||
(nb_samples, nb_classes) or index labels of shape (nb_samples,) (optional).
|
||||
:param batch_size: Size of batches.
|
||||
:param nb_epochs: Number of epochs to use for training.
|
||||
:param save_checkpoints: Boolean, save checkpoints if True.
|
||||
|
|
@ -87,18 +95,24 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
raise ValueError("An optimizer is needed to train the model, but none for provided.")
|
||||
|
||||
_y = check_and_transform_label_format(y, self.nb_classes)
|
||||
if x_validation is None or y_validation is None:
|
||||
x_validation = x
|
||||
y_validation = y
|
||||
|
||||
# Apply preprocessing
|
||||
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, _y, fit=True)
|
||||
|
||||
# Check label shape
|
||||
y_preprocessed = self.reduce_labels(y_preprocessed)
|
||||
|
||||
num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
|
||||
ind = np.arange(len(x_preprocessed))
|
||||
train_dataset = TensorDataset(torch.from_numpy(x_preprocessed), torch.from_numpy(y_preprocessed))
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
if x_validation is None or y_validation is None:
|
||||
val_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
|
||||
logger.info("Using train set for validation")
|
||||
else:
|
||||
_y_val = check_and_transform_label_format(y_validation, self.nb_classes)
|
||||
x_val_preprocessed, y_val_preprocessed = self._apply_preprocessing(x_validation, _y_val, fit=False)
|
||||
# Check label shape
|
||||
y_val_preprocessed = self.reduce_labels(y_val_preprocessed)
|
||||
val_dataset = TensorDataset(torch.from_numpy(x_val_preprocessed), torch.from_numpy(y_val_preprocessed))
|
||||
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
|
||||
|
||||
# Start training
|
||||
for epoch in range(nb_epochs):
|
||||
|
|
@ -106,40 +120,38 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
total = 0
|
||||
best_acc = 0
|
||||
# Shuffle the examples
|
||||
random.shuffle(ind)
|
||||
|
||||
# Train for one epoch
|
||||
for m in range(num_batch):
|
||||
i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device)
|
||||
o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device)
|
||||
|
||||
for inputs, targets in train_loader:
|
||||
inputs = inputs.to(self._device)
|
||||
targets = targets.to(self._device)
|
||||
# Zero the parameter gradients
|
||||
self._optimizer.zero_grad()
|
||||
|
||||
# Perform prediction
|
||||
model_outputs = self._model(i_batch)
|
||||
model_outputs = self._model(inputs)
|
||||
|
||||
# Form the loss function
|
||||
loss = self._loss(model_outputs[-1], o_batch)
|
||||
loss = self._loss(model_outputs[-1], targets)
|
||||
|
||||
loss.backward()
|
||||
|
||||
self._optimizer.step()
|
||||
correct = self.get_step_correct(model_outputs[-1], o_batch)
|
||||
correct = self.get_step_correct(model_outputs[-1], targets)
|
||||
tot_correct += correct
|
||||
total += o_batch.shape[0]
|
||||
total += targets.shape[0]
|
||||
|
||||
val_loss, val_acc = self._eval(val_loader)
|
||||
logger.info(f"Epoch{epoch + 1}/{nb_epochs} Val_Loss: {val_loss}, Val_Acc: {val_acc}")
|
||||
|
||||
val_loss, val_acc = self._eval(x_validation, y_validation, num_batch, batch_size)
|
||||
print(val_acc)
|
||||
best_acc = max(val_acc, best_acc)
|
||||
if save_checkpoints:
|
||||
if save_entire_model:
|
||||
self.save_checkpoint_model(is_best=best_acc <= val_acc)
|
||||
self.save_checkpoint_model(is_best=best_acc <= val_acc, path=path)
|
||||
else:
|
||||
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc)
|
||||
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc, path=path)
|
||||
|
||||
def save_checkpoint_state_dict(self, is_best: bool, path=os.getcwd(),
|
||||
filename="latest.tar") -> None:
|
||||
def save_checkpoint_state_dict(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
|
||||
"""
|
||||
Saves checkpoint as latest.tar or best.tar
|
||||
:param is_best: whether the model is the best achieved model
|
||||
|
|
@ -148,16 +160,19 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
:return: None
|
||||
"""
|
||||
# add path
|
||||
checkpoint = os.path.join(path, 'checkpoints')
|
||||
checkpoint = os.path.join(path, "checkpoints")
|
||||
path = checkpoint
|
||||
os.makedirs(path, exist_ok=True)
|
||||
filepath = os.path.join(path, filename)
|
||||
state = dict()
|
||||
state['state_dict'] = self.model.state_dict()
|
||||
state['opt_state_dict'] = self.optimizer.state_dict()
|
||||
state["state_dict"] = self.model.state_dict()
|
||||
state["opt_state_dict"] = self.optimizer.state_dict()
|
||||
|
||||
logger.info(f"Saving checkpoint state dictionary: {filepath}")
|
||||
torch.save(state, filepath)
|
||||
if is_best:
|
||||
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
|
||||
shutil.copyfile(filepath, os.path.join(path, "model_best.tar"))
|
||||
logger.info(f"Saving best state dictionary checkpoint: {os.path.join(path, 'model_best.tar')}")
|
||||
|
||||
def save_checkpoint_model(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
|
||||
"""
|
||||
|
|
@ -167,13 +182,15 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
:param filename: checkpoint name
|
||||
:return: None
|
||||
"""
|
||||
checkpoint = os.path.join(path, 'checkpoints')
|
||||
checkpoint = os.path.join(path, "checkpoints")
|
||||
path = checkpoint
|
||||
os.makedirs(path, exist_ok=True)
|
||||
filepath = os.path.join(path, filename)
|
||||
logger.info(f"Saving checkpoint model : {filepath}")
|
||||
torch.save(self.model, filepath)
|
||||
if is_best:
|
||||
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
|
||||
shutil.copyfile(filepath, os.path.join(path, "model_best.tar"))
|
||||
logger.info(f"Saving best checkpoint model: {os.path.join(path, 'model_best.tar')}")
|
||||
|
||||
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
|
||||
"""
|
||||
|
|
@ -183,7 +200,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
:return: loaded model
|
||||
"""
|
||||
if path is None:
|
||||
path = os.path.join(os.getcwd(), 'checkpoints')
|
||||
path = os.path.join(os.getcwd(), "checkpoints")
|
||||
|
||||
filepath = os.path.join(path, model_name)
|
||||
if not os.path.exists(filepath):
|
||||
|
|
@ -193,25 +210,26 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
|
||||
else:
|
||||
checkpoint = torch.load(filepath)
|
||||
self.model.load_state_dict(checkpoint['state_dict'])
|
||||
self.model.load_state_dict(checkpoint["state_dict"])
|
||||
self.model.to(self.device)
|
||||
|
||||
if self._optimizer and 'opt_state_dict' in checkpoint:
|
||||
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
|
||||
if self._optimizer and "opt_state_dict" in checkpoint:
|
||||
self._optimizer.load_state_dict(checkpoint["opt_state_dict"])
|
||||
self.model.eval()
|
||||
|
||||
def load_latest_state_dict_checkpoint(self):
|
||||
"""
|
||||
Load model state dict only based on the check point path (latest.tar)
|
||||
:return: loaded model
|
||||
Load model state dict only based on the check point path (latest.tar)
|
||||
:return: loaded model
|
||||
"""
|
||||
self.load_checkpoint_state_dict_by_path('latest.tar')
|
||||
self.load_checkpoint_state_dict_by_path("latest.tar")
|
||||
|
||||
def load_best_state_dict_checkpoint(self):
|
||||
"""
|
||||
Load model state dict only based on the check point path (model_best.tar)
|
||||
:return: loaded model
|
||||
Load model state dict only based on the check point path (model_best.tar)
|
||||
:return: loaded model
|
||||
"""
|
||||
self.load_checkpoint_state_dict_by_path('model_best.tar')
|
||||
self.load_checkpoint_state_dict_by_path("model_best.tar")
|
||||
|
||||
def load_checkpoint_model_by_path(self, model_name: str, path: str = None):
|
||||
"""
|
||||
|
|
@ -221,7 +239,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
:return: loaded model
|
||||
"""
|
||||
if path is None:
|
||||
path = os.path.join(os.getcwd(), 'checkpoints')
|
||||
path = os.path.join(os.getcwd(), "checkpoints")
|
||||
|
||||
filepath = os.path.join(path, model_name)
|
||||
if not os.path.exists(filepath):
|
||||
|
|
@ -230,22 +248,23 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
raise FileNotFoundError(msg)
|
||||
|
||||
else:
|
||||
self._model._model = torch.load(filepath)
|
||||
self._model._model = torch.load(filepath, map_location=self.device)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
def load_latest_model_checkpoint(self):
|
||||
"""
|
||||
Load entire model only based on the check point path (latest.tar)
|
||||
:return: loaded model
|
||||
Load entire model only based on the check point path (latest.tar)
|
||||
:return: loaded model
|
||||
"""
|
||||
self.load_checkpoint_model_by_path('latest.tar')
|
||||
self.load_checkpoint_model_by_path("latest.tar")
|
||||
|
||||
def load_best_model_checkpoint(self):
|
||||
"""
|
||||
Load entire model only based on the check point path (model_best.tar)
|
||||
:return: loaded model
|
||||
Load entire model only based on the check point path (model_best.tar)
|
||||
:return: loaded model
|
||||
"""
|
||||
self.load_checkpoint_model_by_path('model_best.tar')
|
||||
self.load_checkpoint_model_by_path("model_best.tar")
|
||||
|
||||
|
||||
class PyTorchClassifier(PyTorchModel):
|
||||
|
|
@ -253,9 +272,18 @@ class PyTorchClassifier(PyTorchModel):
|
|||
Wrapper class for pytorch classification models.
|
||||
"""
|
||||
|
||||
def __init__(self, model: "torch.nn.Module", output_type: ModelOutputType, loss: "torch.nn.modules.loss._Loss",
|
||||
input_shape: Tuple[int, ...], nb_classes: int, optimizer: "torch.optim.Optimizer",
|
||||
black_box_access: Optional[bool] = True, unlimited_queries: Optional[bool] = True, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
model: "torch.nn.Module",
|
||||
output_type: ModelOutputType,
|
||||
loss: "torch.nn.modules.loss._Loss",
|
||||
input_shape: Tuple[int, ...],
|
||||
nb_classes: int,
|
||||
optimizer: "torch.optim.Optimizer",
|
||||
black_box_access: Optional[bool] = True,
|
||||
unlimited_queries: Optional[bool] = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialization specifically for the PyTorch-based implementation.
|
||||
|
||||
|
|
@ -278,9 +306,17 @@ class PyTorchClassifier(PyTorchModel):
|
|||
super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs)
|
||||
self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)
|
||||
|
||||
def fit(self, train_data: PytorchData, validation_data: PytorchData = None, batch_size: int = 128,
|
||||
nb_epochs: int = 10,
|
||||
save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None:
|
||||
def fit(
|
||||
self,
|
||||
train_data: PytorchData,
|
||||
validation_data: PytorchData = None,
|
||||
batch_size: int = 128,
|
||||
nb_epochs: int = 10,
|
||||
save_checkpoints: bool = True,
|
||||
save_entire_model=True,
|
||||
path=os.getcwd(),
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Fit the model using the training data.
|
||||
|
||||
|
|
@ -296,9 +332,30 @@ class PyTorchClassifier(PyTorchModel):
|
|||
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
|
||||
supported for PyTorch and providing it takes no effect.
|
||||
"""
|
||||
self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1),
|
||||
validation_data.get_samples(), validation_data.get_labels().reshape(-1, 1), batch_size,
|
||||
nb_epochs, save_checkpoints, save_entire_model, path, **kwargs)
|
||||
if validation_data is None:
|
||||
self._art_model.fit(
|
||||
x=train_data.get_samples(),
|
||||
y=train_data.get_labels().reshape(-1, 1),
|
||||
batch_size=batch_size,
|
||||
nb_epochs=nb_epochs,
|
||||
save_checkpoints=save_checkpoints,
|
||||
save_entire_model=save_entire_model,
|
||||
path=path,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self._art_model.fit(
|
||||
x=train_data.get_samples(),
|
||||
y=train_data.get_labels().reshape(-1, 1),
|
||||
x_validation=validation_data.get_samples(),
|
||||
y_validation=validation_data.get_labels().reshape(-1, 1),
|
||||
batch_size=batch_size,
|
||||
nb_epochs=nb_epochs,
|
||||
save_checkpoints=save_checkpoints,
|
||||
save_entire_model=save_entire_model,
|
||||
path=path,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def predict(self, x: PytorchData, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""
|
||||
|
|
@ -332,15 +389,15 @@ class PyTorchClassifier(PyTorchModel):
|
|||
|
||||
def load_latest_state_dict_checkpoint(self):
|
||||
"""
|
||||
Load model state dict only based on the check point path (latest.tar)
|
||||
:return: loaded model
|
||||
Load model state dict only based on the check point path (latest.tar)
|
||||
:return: loaded model
|
||||
"""
|
||||
self._art_model.load_latest_state_dict_checkpoint()
|
||||
|
||||
def load_best_state_dict_checkpoint(self):
|
||||
"""
|
||||
Load model state dict only based on the check point path (model_best.tar)
|
||||
:return: loaded model
|
||||
Load model state dict only based on the check point path (model_best.tar)
|
||||
:return: loaded model
|
||||
"""
|
||||
self._art_model.load_best_state_dict_checkpoint()
|
||||
|
||||
|
|
@ -355,14 +412,14 @@ class PyTorchClassifier(PyTorchModel):
|
|||
|
||||
def load_latest_model_checkpoint(self):
|
||||
"""
|
||||
Load entire model only based on the check point path (latest.tar)
|
||||
:return: loaded model
|
||||
Load entire model only based on the check point path (latest.tar)
|
||||
:return: loaded model
|
||||
"""
|
||||
self._art_model.load_latest_model_checkpoint()
|
||||
|
||||
def load_best_model_checkpoint(self):
|
||||
"""
|
||||
Load entire model only based on the check point path (model_best.tar)
|
||||
:return: loaded model
|
||||
Load entire model only based on the check point path (model_best.tar)
|
||||
:return: loaded model
|
||||
"""
|
||||
self._art_model.load_best_model_checkpoint()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,13 @@
|
|||
numpy==1.21.0
|
||||
pandas==1.1.0
|
||||
scipy==1.4.1
|
||||
scikit-learn==0.22.2
|
||||
numpy~=1.22.3
|
||||
pandas~=1.1.3
|
||||
scipy~=1.5.2
|
||||
scikit-learn~=1.0.2
|
||||
adversarial-robustness-toolkit>=1.9.1
|
||||
|
||||
# testing
|
||||
pytest==5.4.2
|
||||
pytest~=6.1.1
|
||||
|
||||
torch~=1.11.0
|
||||
sklearn~=0.0
|
||||
six~=1.15.0
|
||||
shap~=0.40.0
|
||||
|
|
@ -57,7 +57,7 @@ def test_nursery_pytorch_state_dict():
|
|||
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
||||
optimizer=optimizer, input_shape=(24,),
|
||||
nb_classes=4)
|
||||
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=1000)
|
||||
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10)
|
||||
model.load_latest_state_dict_checkpoint()
|
||||
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||
print('Base model accuracy: ', score)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue