Merge pull request #58 from IBM/pytorch_wrapper

Wrapper for Pytorch models
This commit is contained in:
abigailgold 2022-08-02 17:23:46 +03:00 committed by GitHub
commit 1385f31dcf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 585 additions and 64 deletions

View file

@ -19,9 +19,42 @@ from torch import Tensor
logger = logging.getLogger(__name__)

# Type aliases for the array-like inputs accepted by the dataset wrappers.
# The diff left both old and new definitions in place; keep only the effective
# (final) definitions — the second assignment of each name wins anyway.
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor]
OUTPUT_DATA_ARRAY_TYPE = np.ndarray
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
def array2numpy(arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
    """
    Convert from INPUT_DATA_ARRAY_TYPE to a numpy array.

    :param arr: data to convert (numpy array, pandas DataFrame/Series, list or pytorch Tensor)
    :return: the data as a numpy array
    :raises ValueError: if the input type is not supported
    """
    # isinstance (rather than exact type(...) == checks) also accepts
    # subclasses, e.g. torch.nn.Parameter for the Tensor branch.
    if isinstance(arr, np.ndarray):
        return arr
    if isinstance(arr, (pd.DataFrame, pd.Series)):
        return arr.to_numpy()
    if isinstance(arr, list):
        return np.array(arr)
    if isinstance(arr, Tensor):
        # detach + cpu so tensors with gradients or on GPU convert cleanly
        return arr.detach().cpu().numpy()
    raise ValueError("Non supported type: ", type(arr).__name__)
def array2torch_tensor(arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
    """
    Convert from INPUT_DATA_ARRAY_TYPE to a pytorch Tensor.

    :param arr: data to convert (numpy array, pandas DataFrame/Series, list or pytorch Tensor)
    :return: the data as a pytorch Tensor (returned as-is when already a Tensor)
    :raises ValueError: if the input type is not supported
    """
    # isinstance (rather than exact type(...) == checks) also accepts
    # subclasses, e.g. torch.nn.Parameter for the Tensor branch.
    if isinstance(arr, np.ndarray):
        return torch.from_numpy(arr)
    if isinstance(arr, (pd.DataFrame, pd.Series)):
        return torch.from_numpy(arr.to_numpy())
    if isinstance(arr, list):
        return torch.tensor(arr)
    if isinstance(arr, Tensor):
        return arr
    raise ValueError("Non supported type: ", type(arr).__name__)
class Dataset(metaclass=ABCMeta):
@ -58,46 +91,6 @@ class Dataset(metaclass=ABCMeta):
"""
raise NotImplementedError
def _array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
    """
    Convert from INPUT_DATA_ARRAY_TYPE to a numpy array, recording whether
    the input was a pandas object on ``self.is_pandas``.

    :param arr: the array to transform
    :type arr: numpy array or pandas DataFrame or list or pytorch Tensor
    :return: the array transformed into a numpy array
    :raises ValueError: if the input type is not supported
    """
    arr_type = type(arr)
    if arr_type == np.ndarray:
        return arr
    if arr_type == pd.DataFrame or arr_type == pd.Series:
        # remember the input was pandas so column names can be recovered later
        self.is_pandas = True
        return arr.to_numpy()
    if isinstance(arr, list):
        return np.array(arr)
    if arr_type == Tensor:
        # detach + cpu so tensors with gradients or on GPU convert cleanly
        return arr.detach().cpu().numpy()
    raise ValueError('Non supported type: ', arr_type.__name__)
def _array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
    """
    Convert from INPUT_DATA_ARRAY_TYPE to a pytorch Tensor, recording whether
    the input was a pandas object on ``self.is_pandas``.

    :param arr: the array to transform
    :type arr: numpy array or pandas DataFrame or list or pytorch Tensor
    :return: the array transformed into a pytorch Tensor
    :raises ValueError: if the input type is not supported
    """
    arr_type = type(arr)
    if arr_type == np.ndarray:
        return torch.from_numpy(arr)
    if arr_type == pd.DataFrame or arr_type == pd.Series:
        # remember the input was pandas so column names can be recovered later
        self.is_pandas = True
        return torch.from_numpy(arr.to_numpy())
    if isinstance(arr, list):
        return torch.tensor(arr)
    if arr_type == Tensor:
        return arr
    raise ValueError('Non supported type: ', arr_type.__name__)
class StoredDataset(Dataset):
"""Abstract Class for a Dataset that can be downloaded from a URL and stored in a file"""
@ -146,7 +139,7 @@ class StoredDataset(Dataset):
os.makedirs(dest_path, exist_ok=True)
logger.info("Downloading the dataset...")
urllib.request.urlretrieve(url, file_path)
logger.info('Dataset Downloaded')
logger.info("Dataset Downloaded")
if unzip:
StoredDataset.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
@ -205,7 +198,7 @@ class StoredDataset(Dataset):
logger.info("Shuffling data")
np.random.shuffle(data)
debug_data = data[:int(len(data) * ratio)]
debug_data = data[: int(len(data) * ratio)]
logger.info(f"Saving {ratio} of the data to {dest_datafile}")
np.savetxt(dest_datafile, debug_data, delimiter=delimiter, fmt=fmt)
@ -224,17 +217,19 @@ class ArrayDataset(Dataset):
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None,
             features_names: Optional[list] = None, **kwargs):
    """
    :param x: samples (numpy array, pandas DataFrame/Series, list or pytorch Tensor)
    :param y: labels, same length as x (optional)
    :param features_names: feature names; when x is a DataFrame they must match its columns
    :raises ValueError: on feature-name mismatch or when x and y lengths differ
    """
    # BUG FIX: the original had a duplicated assignment
    # (`self.is_pandas = self.is_pandas = ...`); a single assignment suffices.
    self.is_pandas = isinstance(x, (pd.DataFrame, pd.Series))
    self.features_names = features_names
    self._y = array2numpy(y) if y is not None else None
    self._x = array2numpy(x)
    # BUG FIX: only a DataFrame has `.columns`; the original accessed it for
    # pd.Series input as well, which raised AttributeError.
    if isinstance(x, pd.DataFrame):
        if features_names and not np.array_equal(features_names, x.columns):
            raise ValueError("The supplied features are not the same as in the data features")
        self.features_names = x.columns.to_list()
    if self._y is not None and len(self._x) != len(self._y):
        raise ValueError("Non equivalent lengths of x and y")
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""
@ -278,9 +273,9 @@ class DatasetWithPredictions(Dataset):
y: Optional[INPUT_DATA_ARRAY_TYPE] = None, features_names: Optional[list] = None, **kwargs):
self.is_pandas = False
self.features_names = features_names
self._pred = self._array2numpy(pred)
self._y = self._array2numpy(y) if y is not None else None
self._x = self._array2numpy(x) if x is not None else None
self._pred = array2numpy(pred)
self._y = array2numpy(y) if y is not None else None
self._x = array2numpy(x) if x is not None else None
if self.is_pandas and x is not None:
if features_names and not np.array_equal(features_names, x.columns):
raise ValueError("The supplied features are not the same as in the data features")
@ -327,14 +322,16 @@ class PytorchData(Dataset):
:type y: numpy array or pandas DataFrame or list or pytorch Tensor, optional
"""
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
    """
    :param x: samples (numpy array, pandas DataFrame/Series, list or pytorch Tensor)
    :param y: labels, same length as x (optional)
    :raises ValueError: when x and y lengths differ
    """
    self._y = array2torch_tensor(y) if y is not None else None
    self._x = array2torch_tensor(x)
    self.is_pandas = isinstance(x, (pd.DataFrame, pd.Series))
    # BUG FIX: only a DataFrame has `.columns`; the original accessed it for
    # pd.Series input as well, which raised AttributeError.
    if isinstance(x, pd.DataFrame):
        self.features_names = x.columns
    if self._y is not None and len(self._x) != len(self._y):
        raise ValueError("Non equivalent lengths of x and y")
    if self._y is not None:
        # NOTE(review): assigning __getitem__ on the instance does not affect
        # indexing — Python looks dunders up on the type. Kept for backward
        # compatibility, but this likely never takes effect; confirm intent.
        self.__getitem__ = self.get_item
@ -347,7 +344,7 @@ class PytorchData(Dataset):
:return: samples as numpy array
"""
return self._array2numpy(self._x)
return array2numpy(self._x)
def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE:
    """
    Return the dataset labels.

    :return: labels as numpy array, or None when no labels were supplied
    """
    return array2numpy(self._y) if self._y is not None else None
def get_predictions(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""
@ -392,6 +389,7 @@ class PytorchData(Dataset):
class DatasetFactory:
"""Factory class for dataset creation"""
registry = {}
@classmethod
@ -406,7 +404,7 @@ class DatasetFactory:
def inner_wrapper(wrapped_class: Type[Dataset]) -> Any:
if name in cls.registry:
logger.warning('Dataset %s already exists. Will replace it', name)
logger.warning("Dataset %s already exists. Will replace it", name)
cls.registry[name] = wrapped_class
return wrapped_class
@ -428,7 +426,7 @@ class DatasetFactory:
:return: An instance of the dataset that is created.
"""
if name not in cls.registry:
msg = f'Dataset {name} does not exist in the registry'
msg = f"Dataset {name} does not exist in the registry"
logger.error(msg)
raise ValueError(msg)

View file

@ -0,0 +1,425 @@
""" Pytorch Model Wrapper"""
import os
import shutil
import logging
from typing import Optional, Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from art.utils import check_and_transform_label_format, logger
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import Model, ModelOutputType
from apt.utils.datasets import OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
logger = logging.getLogger(__name__)
class PyTorchModel(Model):
"""
Wrapper class for pytorch models.
"""
class PyTorchClassifierWrapper(ArtPyTorchClassifier):
"""
Wrapper class for pytorch classifier model.
Extension for Pytorch ART model
"""
def get_step_correct(self, outputs, targets) -> int:
    """
    Count the correctly classified samples in one batch.

    :param outputs: raw model outputs for the batch
    :param targets: ground-truth labels for the batch
    :return: number of correct predictions
    :raises ValueError: if outputs and targets differ in length
    """
    if len(outputs) != len(targets):
        raise ValueError("outputs and targets should be the same length.")
    if self.nb_classes > 1:
        return int(torch.sum(torch.argmax(outputs, axis=-1) == targets).item())
    # Single-output head: threshold by rounding the output.
    # BUG FIX: torch.round() takes no `axis` argument — the original
    # torch.round(outputs, axis=-1) raised TypeError whenever nb_classes <= 1.
    return int(torch.sum(torch.round(outputs) == targets).item())
def _eval(self, loader: DataLoader):
    """
    Evaluate the model over a data loader.

    :param loader: DataLoader yielding (inputs, targets) batches
    :return: tuple (average loss per sample, accuracy)
    """
    self.model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(self._device)
        batch_y = batch_y.to(self._device)
        preds = self.model(batch_x)
        batch_loss = self._loss(preds, batch_y)
        batch_size = batch_y.size(0)
        # weight the batch loss by its size so the final average is per-sample
        loss_sum += batch_loss.item() * batch_size
        n_seen += batch_size
        n_correct += self.get_step_correct(preds, batch_y)
    return loss_sum / n_seen, float(n_correct) / n_seen
def fit(
    self,
    x: np.ndarray,
    y: np.ndarray,
    x_validation: np.ndarray = None,
    y_validation: np.ndarray = None,
    batch_size: int = 128,
    nb_epochs: int = 10,
    save_checkpoints: bool = True,
    save_entire_model=True,
    path=os.getcwd(),
    **kwargs,
) -> None:
    """
    Fit the classifier on the training set `(x, y)`.

    :param x: Training data.
    :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels
              of shape (nb_samples,).
    :param x_validation: Validation data (optional; the training set is used when omitted).
    :param y_validation: Target validation values (class labels) one-hot-encoded of shape
                         (nb_samples, nb_classes) or index labels of shape (nb_samples,) (optional).
    :param batch_size: Size of batches.
    :param nb_epochs: Number of epochs to use for training.
    :param save_checkpoints: Boolean, save checkpoints if True.
    :param save_entire_model: Boolean, save entire model if True, else save state dict.
    :param path: path for saving checkpoint. NOTE: the default is the working
                 directory at import time (evaluated once, at def time).
    :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
                   supported for PyTorch and providing it takes no effect.
    :raises ValueError: if no optimizer was configured
    """
    if self._optimizer is None:  # pragma: no cover
        raise ValueError("An optimizer is needed to train the model, but none was provided.")

    y = check_and_transform_label_format(y, self.nb_classes)
    # Apply preprocessing, then reduce one-hot labels to index form
    x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
    y_preprocessed = self.reduce_labels(y_preprocessed)

    train_dataset = TensorDataset(torch.from_numpy(x_preprocessed), torch.from_numpy(y_preprocessed))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    if x_validation is None or y_validation is None:
        # No validation set supplied — evaluate on the training data
        val_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
        logger.info("Using train set for validation")
    else:
        y_val = check_and_transform_label_format(y_validation, self.nb_classes)
        x_val_preprocessed, y_val_preprocessed = self._apply_preprocessing(x_validation, y_val, fit=False)
        y_val_preprocessed = self.reduce_labels(y_val_preprocessed)
        val_dataset = TensorDataset(torch.from_numpy(x_val_preprocessed), torch.from_numpy(y_val_preprocessed))
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    # BUG FIX: best_acc was reset to 0 at the top of every epoch, so every
    # epoch was checkpointed as "best"; track the running best across epochs.
    best_acc = 0
    for epoch in range(nb_epochs):
        # BUG FIX: _eval() leaves the model in eval mode, so training mode
        # must be re-enabled each epoch (the original set it only once).
        self._model.train()
        tot_correct = 0
        total = 0

        for inputs, targets in train_loader:
            inputs = inputs.to(self._device)
            targets = targets.to(self._device)

            # Zero the parameter gradients
            self._optimizer.zero_grad()
            model_outputs = self._model(inputs)
            # model_outputs[-1]: the wrapped ART model appears to return a
            # sequence of outputs whose last entry is the final prediction
            # — TODO confirm against the ART wrapper.
            loss = self._loss(model_outputs[-1], targets)
            loss.backward()
            self._optimizer.step()

            tot_correct += self.get_step_correct(model_outputs[-1], targets)
            total += targets.shape[0]

        val_loss, val_acc = self._eval(val_loader)
        logger.info(f"Epoch{epoch + 1}/{nb_epochs} Val_Loss: {val_loss}, Val_Acc: {val_acc}")
        best_acc = max(val_acc, best_acc)
        if save_checkpoints:
            # val_acc equal to the running best means this epoch is the new best
            is_best = best_acc <= val_acc
            if save_entire_model:
                self.save_checkpoint_model(is_best=is_best, path=path)
            else:
                self.save_checkpoint_state_dict(is_best=is_best, path=path)
def save_checkpoint_state_dict(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
    """
    Save the model and optimizer state dicts as a checkpoint, copying it to
    "model_best.tar" when it is the best model so far.

    :param is_best: whether the model is the best achieved model
    :param path: directory under which the "checkpoints" folder is created
    :param filename: checkpoint name
    :return: None
    """
    ckpt_dir = os.path.join(path, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    filepath = os.path.join(ckpt_dir, filename)
    state = {
        "state_dict": self.model.state_dict(),
        "opt_state_dict": self.optimizer.state_dict(),
    }
    logger.info(f"Saving checkpoint state dictionary: {filepath}")
    torch.save(state, filepath)
    if is_best:
        best_path = os.path.join(ckpt_dir, "model_best.tar")
        shutil.copyfile(filepath, best_path)
        logger.info(f"Saving best state dictionary checkpoint: {best_path}")
def save_checkpoint_model(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
    """
    Save the entire serialized model as a checkpoint, copying it to
    "model_best.tar" when it is the best model so far.

    :param is_best: whether the model is the best achieved model
    :param path: directory under which the "checkpoints" folder is created
    :param filename: checkpoint name
    :return: None
    """
    ckpt_dir = os.path.join(path, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    filepath = os.path.join(ckpt_dir, filename)
    logger.info(f"Saving checkpoint model : {filepath}")
    torch.save(self.model, filepath)
    if is_best:
        best_path = os.path.join(ckpt_dir, "model_best.tar")
        shutil.copyfile(filepath, best_path)
        logger.info(f"Saving best checkpoint model: {best_path}")
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
    """
    Load model (and optimizer, when available) state dicts from a checkpoint file.

    :param model_name: checkpoint filename
    :param path: checkpoint directory (default: "checkpoints" under the current working dir)
    :return: None (the wrapped model is updated in place and set to eval mode)
    :raises FileNotFoundError: if the checkpoint file does not exist
    """
    if path is None:
        path = os.path.join(os.getcwd(), "checkpoints")
    filepath = os.path.join(path, model_name)
    if not os.path.exists(filepath):
        msg = f"Model file {filepath} not found"
        logger.error(msg)
        raise FileNotFoundError(msg)
    # BUG FIX: map checkpoint tensors onto the classifier's device so a
    # checkpoint saved on GPU can be loaded on a CPU-only machine
    # (the original torch.load without map_location fails in that case).
    checkpoint = torch.load(filepath, map_location=self.device)
    self.model.load_state_dict(checkpoint["state_dict"])
    self.model.to(self.device)
    if self._optimizer and "opt_state_dict" in checkpoint:
        self._optimizer.load_state_dict(checkpoint["opt_state_dict"])
    self.model.eval()
def load_latest_state_dict_checkpoint(self):
    """
    Load the most recently saved state-dict checkpoint ("latest.tar")
    from the default checkpoint directory.

    :return: None (the wrapped model is updated in place)
    """
    self.load_checkpoint_state_dict_by_path("latest.tar")
def load_best_state_dict_checkpoint(self):
    """
    Load the best saved state-dict checkpoint ("model_best.tar")
    from the default checkpoint directory.

    :return: None (the wrapped model is updated in place)
    """
    self.load_checkpoint_state_dict_by_path("model_best.tar")
def load_checkpoint_model_by_path(self, model_name: str, path: str = None):
    """
    Load an entire serialized model from a checkpoint file.

    :param model_name: checkpoint filename
    :param path: checkpoint directory (default: "checkpoints" under the current working dir)
    :return: None (the wrapped model is replaced in place and set to eval mode)
    :raises FileNotFoundError: if the checkpoint file does not exist
    """
    ckpt_dir = os.path.join(os.getcwd(), "checkpoints") if path is None else path
    filepath = os.path.join(ckpt_dir, model_name)
    if not os.path.exists(filepath):
        msg = f"Model file {filepath} not found"
        logger.error(msg)
        raise FileNotFoundError(msg)
    # NOTE(review): this replaces the inner module on ART's model wrapper
    # directly (`self._model._model`) — presumably _model is ART's wrapper
    # object holding the raw torch module; confirm against the ART version used.
    self._model._model = torch.load(filepath, map_location=self.device)
    self.model.to(self.device)
    self.model.eval()
def load_latest_model_checkpoint(self):
    """
    Load the most recently saved entire-model checkpoint ("latest.tar")
    from the default checkpoint directory.

    :return: None (the wrapped model is replaced in place)
    """
    self.load_checkpoint_model_by_path("latest.tar")
def load_best_model_checkpoint(self):
    """
    Load the best saved entire-model checkpoint ("model_best.tar")
    from the default checkpoint directory.

    :return: None (the wrapped model is replaced in place)
    """
    self.load_checkpoint_model_by_path("model_best.tar")
class PyTorchClassifier(PyTorchModel):
"""
Wrapper class for pytorch classification models.
"""
def __init__(
    self,
    model: "torch.nn.Module",
    output_type: ModelOutputType,
    loss: "torch.nn.modules.loss._Loss",
    input_shape: Tuple[int, ...],
    nb_classes: int,
    optimizer: "torch.optim.Optimizer",
    black_box_access: Optional[bool] = True,
    unlimited_queries: Optional[bool] = True,
    **kwargs,
):
    """
    Initialization specifically for the PyTorch-based implementation.

    :param model: PyTorch model. The output of the model can be logits, probabilities or anything else. Logits
                  output should be preferred where possible to ensure attack efficiency.
    :param output_type: The type of output the model yields (vector/label only for classifiers,
                        value for regressors)
    :param loss: The loss function for which to compute gradients for training. The target label must be raw
                 categorical, i.e. not converted to one-hot encoding.
    :param input_shape: The shape of one input instance.
    :param nb_classes: The number of prediction classes.
    :param optimizer: The optimizer used to train the classifier.
    :param black_box_access: Boolean describing the type of deployment of the model (when in production).
                             Set to True if the model is only available via query (API) access, i.e.,
                             only the outputs of the model are exposed, and False if the model internals
                             are also available. Optional, Default is True.
    :param unlimited_queries: If black_box_access is True, this boolean indicates whether a user can perform
                              unlimited queries to the model API or whether there is a limit to the number of
                              queries that can be submitted. Optional, Default is True.
    """
    super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs)
    # Training/prediction mechanics are delegated to the ART-based wrapper
    self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)
def fit(
    self,
    train_data: PytorchData,
    validation_data: PytorchData = None,
    batch_size: int = 128,
    nb_epochs: int = 10,
    save_checkpoints: bool = True,
    save_entire_model=True,
    path=os.getcwd(),
    **kwargs,
) -> None:
    """
    Fit the model using the training data.

    :param train_data: Training data.
    :type train_data: `PytorchData`
    :param validation_data: Validation data; the training data is used for
                            validation when omitted. (Doc fix: previously described as "Training data".)
    :type validation_data: `PytorchData`
    :param batch_size: Size of batches.
    :param nb_epochs: Number of epochs to use for training.
    :param save_checkpoints: Boolean, save checkpoints if True.
    :param save_entire_model: Boolean, save entire model if True, else save state dict.
    :param path: path for saving checkpoint.
    :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
                   supported for PyTorch and providing it takes no effect.
    """
    # Build the argument set once instead of duplicating the whole call in
    # both branches; add the validation arrays only when supplied.
    fit_args = {
        "x": train_data.get_samples(),
        "y": train_data.get_labels().reshape(-1, 1),
        "batch_size": batch_size,
        "nb_epochs": nb_epochs,
        "save_checkpoints": save_checkpoints,
        "save_entire_model": save_entire_model,
        "path": path,
    }
    if validation_data is not None:
        fit_args["x_validation"] = validation_data.get_samples()
        fit_args["y_validation"] = validation_data.get_labels().reshape(-1, 1)
    self._art_model.fit(**fit_args, **kwargs)
def predict(self, x: PytorchData, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
    """
    Perform predictions using the model for input `x`.

    :param x: Input samples.
    :type x: `PytorchData`
    :return: Predictions from the model (class probabilities, if supported) as a numpy array.
    """
    return self._art_model.predict(x.get_samples(), **kwargs)
def score(self, test_data: PytorchData, **kwargs):
    """
    Score the model using test data.

    :param test_data: Test data.
    :type test_data: `PytorchData`
    :return: accuracy as a float between 0 and 1
    """
    labels = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes)
    predictions = self.predict(test_data)
    # accuracy = fraction of samples whose argmax prediction matches the label
    matches = np.argmax(labels, axis=1) == np.argmax(predictions, axis=1)
    return np.count_nonzero(matches) / predictions.shape[0]
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
    """
    Load model state dicts from a checkpoint file (delegates to the ART wrapper).

    :param model_name: checkpoint filename
    :param path: checkpoint directory (default: "checkpoints" under the current working dir)
    :return: None (the wrapped model is updated in place)
    """
    self._art_model.load_checkpoint_state_dict_by_path(model_name, path)
def load_latest_state_dict_checkpoint(self):
    """
    Load the latest state-dict checkpoint ("latest.tar") via the ART wrapper.

    :return: None (the wrapped model is updated in place)
    """
    self._art_model.load_latest_state_dict_checkpoint()
def load_best_state_dict_checkpoint(self):
    """
    Load the best state-dict checkpoint ("model_best.tar") via the ART wrapper.

    :return: None (the wrapped model is updated in place)
    """
    self._art_model.load_best_state_dict_checkpoint()
def load_checkpoint_model_by_path(self, model_name: str, path: str = None):
    """
    Load an entire serialized model from a checkpoint file (delegates to the ART wrapper).

    :param model_name: checkpoint filename
    :param path: checkpoint directory (default: "checkpoints" under the current working dir)
    :return: None (the wrapped model is replaced in place)
    """
    self._art_model.load_checkpoint_model_by_path(model_name, path)
def load_latest_model_checkpoint(self):
    """
    Load the latest entire-model checkpoint ("latest.tar") via the ART wrapper.

    :return: None (the wrapped model is replaced in place)
    """
    self._art_model.load_latest_model_checkpoint()
def load_best_model_checkpoint(self):
    """
    Load the best entire-model checkpoint ("model_best.tar") via the ART wrapper.

    :return: None (the wrapped model is replaced in place)
    """
    self._art_model.load_best_model_checkpoint()

View file

@ -1,9 +1,9 @@
numpy>=1.22
pandas==1.1.0
scipy==1.4.1
scikit-learn==0.22.2
numpy~=1.22
pandas~=1.1.0
scipy>=1.4.1
scikit-learn>=0.22.2
torch>=1.8.0
adversarial-robustness-toolbox>=1.11.0
# testing
pytest==5.4.2
pytest>=5.4.2

98
tests/test_pytorch.py Normal file
View file

@ -0,0 +1,98 @@
import numpy as np
import torch
from torch import nn, optim
from apt.utils.datasets import ArrayDataset
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import ModelOutputType
from apt.utils.models.pytorch_model import PyTorchClassifier
from art.utils import load_nursery
class pytorch_model(nn.Module):
    """Small fully-connected test classifier: four Linear+Tanh blocks and a linear head."""

    def __init__(self, num_classes, num_features):
        super(pytorch_model, self).__init__()
        # Widths of the hidden blocks: num_features -> 1024 -> 512 -> 256 -> 128
        widths = [num_features, 1024, 512, 256, 128]
        blocks = [
            nn.Sequential(nn.Linear(n_in, n_out), nn.Tanh())
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        # Assign in order so module registration order matches the original
        self.fc1, self.fc2, self.fc3, self.fc4 = blocks
        self.classifier = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        # Run through the hidden blocks, then the classification head
        for block in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = block(x)
        return self.classifier(x)
def test_nursery_pytorch_state_dict():
    """End-to-end: train the wrapper on nursery with state-dict checkpoints,
    reload the latest and best checkpoints, and verify scoring still works."""
    (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
    # reduce size of training set to make attack slightly better
    train_set_size = 500
    x_train = x_train[:train_set_size]
    y_train = y_train[:train_set_size]
    x_test = x_test[:train_set_size]
    y_test = y_test[:train_set_size]

    inner_model = pytorch_model(4, 24)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(inner_model.parameters(), lr=0.01)
    model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion,
                              optimizer=optimizer, input_shape=(24,),
                              nb_classes=4)
    # save_entire_model=False -> checkpoints store state dicts only
    model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10)

    # Reload the most recent checkpoint and confirm the score is a valid accuracy
    model.load_latest_state_dict_checkpoint()
    score = model.score(PytorchData(x_test.astype(np.float32), y_test))
    print('Base model accuracy: ', score)
    assert (0 <= score <= 1)

    # Reload the best checkpoint and confirm the score is a valid accuracy
    model.load_best_state_dict_checkpoint()
    score = model.score(PytorchData(x_test.astype(np.float32), y_test))
    print('best model accuracy: ', score)
    assert (0 <= score <= 1)
def test_nursery_pytorch_save_entire_model():
    """End-to-end: train the wrapper on nursery with entire-model checkpoints,
    reload the best checkpoint, and verify scoring still works."""
    (x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
    # reduce size of training set to make attack slightly better
    train_set_size = 500
    x_train = x_train[:train_set_size]
    y_train = y_train[:train_set_size]
    x_test = x_test[:train_set_size]
    y_test = y_test[:train_set_size]

    model = pytorch_model(4, 24)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion,
                                  optimizer=optimizer, input_shape=(24,),
                                  nb_classes=4)
    # save_entire_model=True -> checkpoints serialize the full model object
    art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10)

    score = art_model.score(PytorchData(x_test.astype(np.float32), y_test))
    print('Base model accuracy: ', score)
    assert (0 <= score <= 1)

    # Reload the best full-model checkpoint and confirm the score is a valid accuracy
    art_model.load_best_model_checkpoint()
    score = art_model.score(PytorchData(x_test.astype(np.float32), y_test))
    print('best model accuracy: ', score)
    assert (0 <= score <= 1)