update pytorch wrapper to use torch loaders

fix tests
and dataset style
This commit is contained in:
Ron Shmelkin 2022-07-24 14:31:47 +03:00
parent fdc6005fce
commit c77e34e373
No known key found for this signature in database
GPG key ID: A4289A6607B5C294
4 changed files with 178 additions and 113 deletions

View file

@ -39,7 +39,7 @@ def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
if type(arr) == Tensor:
return arr.detach().cpu().numpy()
raise ValueError('Non supported type: ', type(arr).__name__)
raise ValueError("Non supported type: ", type(arr).__name__)
def array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
@ -56,7 +56,7 @@ def array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
if type(arr) == Tensor:
return arr
raise ValueError('Non supported type: ', type(arr).__name__)
raise ValueError("Non supported type: ", type(arr).__name__)
class Dataset(metaclass=ABCMeta):
@ -109,7 +109,7 @@ class StoredDataset(Dataset):
os.makedirs(dest_path, exist_ok=True)
logger.info("Downloading the dataset...")
urllib.request.urlretrieve(url, file_path)
logger.info('Dataset Downloaded')
logger.info("Dataset Downloaded")
if unzip:
StoredDataset.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
@ -156,7 +156,7 @@ class StoredDataset(Dataset):
logger.info("Shuffling data")
np.random.shuffle(data)
debug_data = data[:int(len(data) * ratio)]
debug_data = data[: int(len(data) * ratio)]
logger.info(f"Saving {ratio} of the data to {dest_datafile}")
np.savetxt(dest_datafile, debug_data, delimiter=delimiter, fmt=fmt)
@ -164,8 +164,13 @@ class StoredDataset(Dataset):
class ArrayDataset(Dataset):
"""Dataset that is based on x and y arrays (e.g., numpy/pandas/list...)"""
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None,
features_names: Optional = None, **kwargs):
def __init__(
self,
x: INPUT_DATA_ARRAY_TYPE,
y: Optional[INPUT_DATA_ARRAY_TYPE] = None,
features_names: Optional = None,
**kwargs,
):
"""
ArrayDataset constructor.
:param x: collection of data samples
@ -183,7 +188,7 @@ class ArrayDataset(Dataset):
self.features_names = x.columns.to_list()
if y is not None and len(self._x) != len(self._y):
raise ValueError('Non equivalent lengths of x and y')
raise ValueError("Non equivalent lengths of x and y")
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""Return data samples as numpy array"""
@ -195,7 +200,6 @@ class ArrayDataset(Dataset):
class PytorchData(Dataset):
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
"""
PytorchData constructor.
@ -210,15 +214,13 @@ class PytorchData(Dataset):
self.features_names = x.columns
if y is not None and len(self._x) != len(self._y):
raise ValueError('Non equivalent lengths of x and y')
raise ValueError("Non equivalent lengths of x and y")
if self._y is not None:
self.__getitem__ = self.get_item
else:
self.__getitem__ = self.get_sample_item
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""Return data samples as numpy array"""
return array2numpy(self, self._x)
@ -240,6 +242,7 @@ class PytorchData(Dataset):
class DatasetFactory:
"""Factory class for dataset creation"""
registry = {}
@classmethod
@ -252,7 +255,7 @@ class DatasetFactory:
def inner_wrapper(wrapped_class: Dataset) -> Any:
if name in cls.registry:
logger.warning('Dataset %s already exists. Will replace it', name)
logger.warning("Dataset %s already exists. Will replace it", name)
cls.registry[name] = wrapped_class
return wrapped_class
@ -270,7 +273,7 @@ class DatasetFactory:
:return: An instance of the dataset that is created.
"""
if name not in cls.registry:
msg = f'Dataset {name} does not exist in the registry'
msg = f"Dataset {name} does not exist in the registry"
logger.error(msg)
raise ValueError(msg)

View file

@ -1,17 +1,21 @@
import logging
""" Pytorch Model Wrapper"""
import os
import random
import shutil
from typing import Optional, Tuple
import logging
from typing import Optional, Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from art.utils import check_and_transform_label_format, logger
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import Model, ModelOutputType
from apt.utils.datasets import OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
import torch
logger = logging.getLogger(__name__)
class PyTorchModel(Model):
@ -22,9 +26,9 @@ class PyTorchModel(Model):
class PyTorchClassifierWrapper(ArtPyTorchClassifier):
"""
Wrapper class for pytorch classifier model.
Extension for Pytorch ART model
"""
Wrapper class for pytorch classifier model.
Extension for Pytorch ART model
"""
def get_step_correct(self, outputs, targets) -> int:
"""get number of correctly classified labels"""
@ -35,43 +39,47 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
else:
return int(torch.sum(torch.round(outputs, axis=-1) == targets).item())
def _eval(self, x: np.ndarray, y: np.ndarray, nb_epochs, batch_size):
def _eval(self, loader: DataLoader):
"""inner function for model evaluation"""
self.model.eval()
total_loss = 0
correct = 0
total = 0
y = check_and_transform_label_format(y, self.nb_classes)
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
y_preprocessed = self.reduce_labels(y_preprocessed)
num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
ind = np.arange(len(x_preprocessed))
for epoch in range(nb_epochs):
random.shuffle(ind)
for m in range(num_batch):
inputs = torch.from_numpy(x_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device)
targets = torch.from_numpy(y_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device)
targets = targets.to(self.device)
outputs = self.model(inputs)
loss = self._loss(outputs, targets)
total_loss += (loss.item() * targets.size(0))
total += targets.size(0)
correct += self.get_step_correct(outputs, targets)
for inputs, targets in loader:
inputs = inputs.to(self._device)
targets = targets.to(self._device)
outputs = self.model(inputs)
loss = self._loss(outputs, targets)
total_loss += loss.item() * targets.size(0)
total += targets.size(0)
correct += self.get_step_correct(outputs, targets)
return total_loss / total, float(correct) / total
def fit(self, x: np.ndarray, y: np.ndarray, x_validation: np.ndarray = None, y_validation: np.ndarray = None,
batch_size: int = 128, nb_epochs: int = 10, save_checkpoints: bool = True, save_entire_model=True,
path=os.getcwd(), **kwargs) -> None:
def fit(
self,
x: np.ndarray,
y: np.ndarray,
x_validation: np.ndarray = None,
y_validation: np.ndarray = None,
batch_size: int = 128,
nb_epochs: int = 10,
save_checkpoints: bool = True,
save_entire_model=True,
path=os.getcwd(),
**kwargs,
) -> None:
"""
Fit the classifier on the training set `(x, y)`.
:param x: Training data.
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels
of shape (nb_samples,).
:param x_validation: Validation data (optional).
:param y_validation: Target validation values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels
of shape (nb_samples,) (optional).
:param y_validation: Target validation values (class labels) one-hot-encoded of shape
(nb_samples, nb_classes) or index labels of shape (nb_samples,) (optional).
:param batch_size: Size of batches.
:param nb_epochs: Number of epochs to use for training.
:param save_checkpoints: Boolean, save checkpoints if True.
@ -87,18 +95,24 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
raise ValueError("An optimizer is needed to train the model, but none was provided.")
_y = check_and_transform_label_format(y, self.nb_classes)
if x_validation is None or y_validation is None:
x_validation = x
y_validation = y
# Apply preprocessing
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, _y, fit=True)
# Check label shape
y_preprocessed = self.reduce_labels(y_preprocessed)
num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
ind = np.arange(len(x_preprocessed))
train_dataset = TensorDataset(torch.from_numpy(x_preprocessed), torch.from_numpy(y_preprocessed))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
if x_validation is None or y_validation is None:
val_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
logger.info("Using train set for validation")
else:
_y_val = check_and_transform_label_format(y_validation, self.nb_classes)
x_val_preprocessed, y_val_preprocessed = self._apply_preprocessing(x_validation, _y_val, fit=False)
# Check label shape
y_val_preprocessed = self.reduce_labels(y_val_preprocessed)
val_dataset = TensorDataset(torch.from_numpy(x_val_preprocessed), torch.from_numpy(y_val_preprocessed))
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
# Start training
for epoch in range(nb_epochs):
@ -106,40 +120,38 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
total = 0
best_acc = 0
# Shuffle the examples
random.shuffle(ind)
# Train for one epoch
for m in range(num_batch):
i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device)
o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device)
for inputs, targets in train_loader:
inputs = inputs.to(self._device)
targets = targets.to(self._device)
# Zero the parameter gradients
self._optimizer.zero_grad()
# Perform prediction
model_outputs = self._model(i_batch)
model_outputs = self._model(inputs)
# Form the loss function
loss = self._loss(model_outputs[-1], o_batch)
loss = self._loss(model_outputs[-1], targets)
loss.backward()
self._optimizer.step()
correct = self.get_step_correct(model_outputs[-1], o_batch)
correct = self.get_step_correct(model_outputs[-1], targets)
tot_correct += correct
total += o_batch.shape[0]
total += targets.shape[0]
val_loss, val_acc = self._eval(val_loader)
logger.info(f"Epoch {epoch + 1}/{nb_epochs} Val_Loss: {val_loss}, Val_Acc: {val_acc}")
val_loss, val_acc = self._eval(x_validation, y_validation, num_batch, batch_size)
print(val_acc)
best_acc = max(val_acc, best_acc)
if save_checkpoints:
if save_entire_model:
self.save_checkpoint_model(is_best=best_acc <= val_acc)
self.save_checkpoint_model(is_best=best_acc <= val_acc, path=path)
else:
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc)
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc, path=path)
def save_checkpoint_state_dict(self, is_best: bool, path=os.getcwd(),
filename="latest.tar") -> None:
def save_checkpoint_state_dict(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
"""
Saves checkpoint as latest.tar or best.tar
:param is_best: whether the model is the best achieved model
@ -148,16 +160,19 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
:return: None
"""
# add path
checkpoint = os.path.join(path, 'checkpoints')
checkpoint = os.path.join(path, "checkpoints")
path = checkpoint
os.makedirs(path, exist_ok=True)
filepath = os.path.join(path, filename)
state = dict()
state['state_dict'] = self.model.state_dict()
state['opt_state_dict'] = self.optimizer.state_dict()
state["state_dict"] = self.model.state_dict()
state["opt_state_dict"] = self.optimizer.state_dict()
logger.info(f"Saving checkpoint state dictionary: {filepath}")
torch.save(state, filepath)
if is_best:
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
shutil.copyfile(filepath, os.path.join(path, "model_best.tar"))
logger.info(f"Saving best state dictionary checkpoint: {os.path.join(path, 'model_best.tar')}")
def save_checkpoint_model(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
"""
@ -167,13 +182,15 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
:param filename: checkpoint name
:return: None
"""
checkpoint = os.path.join(path, 'checkpoints')
checkpoint = os.path.join(path, "checkpoints")
path = checkpoint
os.makedirs(path, exist_ok=True)
filepath = os.path.join(path, filename)
logger.info(f"Saving checkpoint model : {filepath}")
torch.save(self.model, filepath)
if is_best:
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
shutil.copyfile(filepath, os.path.join(path, "model_best.tar"))
logger.info(f"Saving best checkpoint model: {os.path.join(path, 'model_best.tar')}")
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
"""
@ -183,7 +200,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
:return: loaded model
"""
if path is None:
path = os.path.join(os.getcwd(), 'checkpoints')
path = os.path.join(os.getcwd(), "checkpoints")
filepath = os.path.join(path, model_name)
if not os.path.exists(filepath):
@ -193,25 +210,26 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
else:
checkpoint = torch.load(filepath)
self.model.load_state_dict(checkpoint['state_dict'])
self.model.load_state_dict(checkpoint["state_dict"])
self.model.to(self.device)
if self._optimizer and 'opt_state_dict' in checkpoint:
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
if self._optimizer and "opt_state_dict" in checkpoint:
self._optimizer.load_state_dict(checkpoint["opt_state_dict"])
self.model.eval()
def load_latest_state_dict_checkpoint(self):
"""
Load model state dict only based on the check point path (latest.tar)
:return: loaded model
Load model state dict only based on the check point path (latest.tar)
:return: loaded model
"""
self.load_checkpoint_state_dict_by_path('latest.tar')
self.load_checkpoint_state_dict_by_path("latest.tar")
def load_best_state_dict_checkpoint(self):
"""
Load model state dict only based on the check point path (model_best.tar)
:return: loaded model
Load model state dict only based on the check point path (model_best.tar)
:return: loaded model
"""
self.load_checkpoint_state_dict_by_path('model_best.tar')
self.load_checkpoint_state_dict_by_path("model_best.tar")
def load_checkpoint_model_by_path(self, model_name: str, path: str = None):
"""
@ -221,7 +239,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
:return: loaded model
"""
if path is None:
path = os.path.join(os.getcwd(), 'checkpoints')
path = os.path.join(os.getcwd(), "checkpoints")
filepath = os.path.join(path, model_name)
if not os.path.exists(filepath):
@ -230,22 +248,23 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
raise FileNotFoundError(msg)
else:
self._model._model = torch.load(filepath)
self._model._model = torch.load(filepath, map_location=self.device)
self.model.to(self.device)
self.model.eval()
def load_latest_model_checkpoint(self):
"""
Load entire model only based on the check point path (latest.tar)
:return: loaded model
Load entire model only based on the check point path (latest.tar)
:return: loaded model
"""
self.load_checkpoint_model_by_path('latest.tar')
self.load_checkpoint_model_by_path("latest.tar")
def load_best_model_checkpoint(self):
"""
Load entire model only based on the check point path (model_best.tar)
:return: loaded model
Load entire model only based on the check point path (model_best.tar)
:return: loaded model
"""
self.load_checkpoint_model_by_path('model_best.tar')
self.load_checkpoint_model_by_path("model_best.tar")
class PyTorchClassifier(PyTorchModel):
@ -253,9 +272,18 @@ class PyTorchClassifier(PyTorchModel):
Wrapper class for pytorch classification models.
"""
def __init__(self, model: "torch.nn.Module", output_type: ModelOutputType, loss: "torch.nn.modules.loss._Loss",
input_shape: Tuple[int, ...], nb_classes: int, optimizer: "torch.optim.Optimizer",
black_box_access: Optional[bool] = True, unlimited_queries: Optional[bool] = True, **kwargs):
def __init__(
self,
model: "torch.nn.Module",
output_type: ModelOutputType,
loss: "torch.nn.modules.loss._Loss",
input_shape: Tuple[int, ...],
nb_classes: int,
optimizer: "torch.optim.Optimizer",
black_box_access: Optional[bool] = True,
unlimited_queries: Optional[bool] = True,
**kwargs,
):
"""
Initialization specifically for the PyTorch-based implementation.
@ -278,9 +306,17 @@ class PyTorchClassifier(PyTorchModel):
super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs)
self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)
def fit(self, train_data: PytorchData, validation_data: PytorchData = None, batch_size: int = 128,
nb_epochs: int = 10,
save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None:
def fit(
self,
train_data: PytorchData,
validation_data: PytorchData = None,
batch_size: int = 128,
nb_epochs: int = 10,
save_checkpoints: bool = True,
save_entire_model=True,
path=os.getcwd(),
**kwargs,
) -> None:
"""
Fit the model using the training data.
@ -296,9 +332,30 @@ class PyTorchClassifier(PyTorchModel):
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
supported for PyTorch and providing it takes no effect.
"""
self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1),
validation_data.get_samples(), validation_data.get_labels().reshape(-1, 1), batch_size,
nb_epochs, save_checkpoints, save_entire_model, path, **kwargs)
if validation_data is None:
self._art_model.fit(
x=train_data.get_samples(),
y=train_data.get_labels().reshape(-1, 1),
batch_size=batch_size,
nb_epochs=nb_epochs,
save_checkpoints=save_checkpoints,
save_entire_model=save_entire_model,
path=path,
**kwargs,
)
else:
self._art_model.fit(
x=train_data.get_samples(),
y=train_data.get_labels().reshape(-1, 1),
x_validation=validation_data.get_samples(),
y_validation=validation_data.get_labels().reshape(-1, 1),
batch_size=batch_size,
nb_epochs=nb_epochs,
save_checkpoints=save_checkpoints,
save_entire_model=save_entire_model,
path=path,
**kwargs,
)
def predict(self, x: PytorchData, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
"""
@ -332,15 +389,15 @@ class PyTorchClassifier(PyTorchModel):
def load_latest_state_dict_checkpoint(self):
"""
Load model state dict only based on the check point path (latest.tar)
:return: loaded model
Load model state dict only based on the check point path (latest.tar)
:return: loaded model
"""
self._art_model.load_latest_state_dict_checkpoint()
def load_best_state_dict_checkpoint(self):
"""
Load model state dict only based on the check point path (model_best.tar)
:return: loaded model
Load model state dict only based on the check point path (model_best.tar)
:return: loaded model
"""
self._art_model.load_best_state_dict_checkpoint()
@ -355,14 +412,14 @@ class PyTorchClassifier(PyTorchModel):
def load_latest_model_checkpoint(self):
"""
Load entire model only based on the check point path (latest.tar)
:return: loaded model
Load entire model only based on the check point path (latest.tar)
:return: loaded model
"""
self._art_model.load_latest_model_checkpoint()
def load_best_model_checkpoint(self):
"""
Load entire model only based on the check point path (model_best.tar)
:return: loaded model
Load entire model only based on the check point path (model_best.tar)
:return: loaded model
"""
self._art_model.load_best_model_checkpoint()

View file

@ -1,8 +1,13 @@
numpy==1.21.0
pandas==1.1.0
scipy==1.4.1
scikit-learn==0.22.2
numpy~=1.22.3
pandas~=1.1.3
scipy~=1.5.2
scikit-learn~=1.0.2
adversarial-robustness-toolbox>=1.9.1
# testing
pytest==5.4.2
pytest~=6.1.1
torch~=1.11.0
sklearn~=0.0
six~=1.15.0
shap~=0.40.0

View file

@ -57,7 +57,7 @@ def test_nursery_pytorch_state_dict():
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=1000)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10)
model.load_latest_state_dict_checkpoint()
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)