This commit is contained in:
olasaadi 2022-07-19 21:16:39 +03:00
parent 07e64b1f86
commit 4973fbebc6
2 changed files with 34 additions and 23 deletions

View file

@@ -23,6 +23,7 @@ class PyTorchModel(Model):
class PyTorchClassifierWrapper(ArtPyTorchClassifier):
"""
Wrapper class for pytorch classifier model.
Extension for Pytorch ART model
"""
def get_step_correct(self, outputs, targets) -> int:
@@ -187,6 +188,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
if self._optimizer and 'opt_state_dict' in checkpoint:
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
self.model.eval()
def load_latest_state_dict_checkpoint(self):
"""
@@ -266,14 +268,23 @@ class PyTorchClassifier(PyTorchModel):
super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs)
self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)
def fit(self, train_data: PytorchData, batch_size: int = 128, nb_epochs: int = 10,
        save_checkpoints: bool = True, save_entire_model: bool = True, path=None, **kwargs) -> None:
    """
    Fit the model using the training data.

    :param train_data: Training data.
    :type train_data: `Dataset`
    :param batch_size: Size of batches.
    :param nb_epochs: Number of epochs to use for training.
    :param save_checkpoints: Boolean, save checkpoints if True.
    :param save_entire_model: Boolean, save entire model if True, else save state dict.
    :param path: Path for saving checkpoints. Defaults to the current working
                 directory, resolved when `fit` is called.
    :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
                   supported for PyTorch and providing it takes no effect.
    """
    if path is None:
        # Resolve the checkpoint directory at call time. A default of
        # `path=os.getcwd()` in the signature would be evaluated once at
        # import time and silently freeze the importer's cwd.
        path = os.getcwd()
    # Labels are reshaped to a column vector, the layout the underlying
    # ART wrapper's fit() expects.
    self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), batch_size, nb_epochs,
                        save_checkpoints, save_entire_model, path, **kwargs)
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
"""

View file

@@ -2,7 +2,7 @@ import numpy as np
import torch
from torch import nn, optim
from apt.utils.datasets import ArrayDataset, Data, Dataset
from apt.utils.datasets import ArrayDataset
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import ModelOutputType
from apt.utils.models.pytorch_model import PyTorchClassifier
@@ -40,6 +40,7 @@ class pytorch_model(nn.Module):
out = self.fc4(out)
return self.classifier(out)
def test_nursery_pytorch_state_dict():
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
# reduce size of training set to make attack slightly better
@@ -49,22 +50,21 @@ def test_nursery_pytorch_state_dict():
x_test = x_test[:train_set_size]
y_test = y_test[:train_set_size]
model = pytorch_model(4, 24)
# model = torch.nn.DataParallel(model)
inner_model = pytorch_model(4, 24)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
optimizer = optim.Adam(inner_model.parameters(), lr=0.01)
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False)
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
art_model.load_best_state_dict_checkpoint()
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100)
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)
model.load_best_state_dict_checkpoint()
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)
def test_nursery_pytorch_save_entire_model():
@@ -76,8 +76,6 @@ def test_nursery_pytorch_save_entire_model():
x_test = x_test[:train_set_size]
y_test = y_test[:train_set_size]
model = pytorch_model(4, 24)
# model = torch.nn.DataParallel(model)
criterion = nn.CrossEntropyLoss()
@@ -88,8 +86,10 @@ def test_nursery_pytorch_save_entire_model():
nb_classes=4)
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True)
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)
art_model.load_best_model_checkpoint()
#score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)