mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-25 04:46:21 +02:00
fix
This commit is contained in:
parent
07e64b1f86
commit
4973fbebc6
2 changed files with 34 additions and 23 deletions
|
|
@ -23,6 +23,7 @@ class PyTorchModel(Model):
|
||||||
class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
||||||
"""
|
"""
|
||||||
Wrapper class for pytorch classifier model.
|
Wrapper class for pytorch classifier model.
|
||||||
|
Extension for Pytorch ART model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_step_correct(self, outputs, targets) -> int:
|
def get_step_correct(self, outputs, targets) -> int:
|
||||||
|
|
@ -187,6 +188,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
||||||
|
|
||||||
if self._optimizer and 'opt_state_dict' in checkpoint:
|
if self._optimizer and 'opt_state_dict' in checkpoint:
|
||||||
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
|
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
def load_latest_state_dict_checkpoint(self):
|
def load_latest_state_dict_checkpoint(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -266,14 +268,23 @@ class PyTorchClassifier(PyTorchModel):
|
||||||
super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs)
|
super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs)
|
||||||
self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)
|
self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)
|
||||||
|
|
||||||
def fit(self, train_data: PytorchData, **kwargs) -> None:
|
def fit(self, train_data: PytorchData, batch_size: int = 128, nb_epochs: int = 10,
|
||||||
|
save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Fit the model using the training data.
|
Fit the model using the training data.
|
||||||
|
|
||||||
:param train_data: Training data.
|
:param train_data: Training data.
|
||||||
:type train_data: `Dataset`
|
:type train_data: `Dataset`
|
||||||
|
:param batch_size: Size of batches.
|
||||||
|
:param nb_epochs: Number of epochs to use for training.
|
||||||
|
:param save_checkpoints: Boolean, save checkpoints if True.
|
||||||
|
:param save_entire_model: Boolean, save entire model if True, else save state dict.
|
||||||
|
:param path: path for saving checkpoint.
|
||||||
|
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
|
||||||
|
supported for PyTorch and providing it takes no effect.
|
||||||
"""
|
"""
|
||||||
self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), **kwargs)
|
self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), batch_size, nb_epochs,
|
||||||
|
save_checkpoints, save_entire_model, path, **kwargs)
|
||||||
|
|
||||||
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
|
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, optim
|
from torch import nn, optim
|
||||||
|
|
||||||
from apt.utils.datasets import ArrayDataset, Data, Dataset
|
from apt.utils.datasets import ArrayDataset
|
||||||
from apt.utils.datasets.datasets import PytorchData
|
from apt.utils.datasets.datasets import PytorchData
|
||||||
from apt.utils.models import ModelOutputType
|
from apt.utils.models import ModelOutputType
|
||||||
from apt.utils.models.pytorch_model import PyTorchClassifier
|
from apt.utils.models.pytorch_model import PyTorchClassifier
|
||||||
|
|
@ -40,6 +40,7 @@ class pytorch_model(nn.Module):
|
||||||
out = self.fc4(out)
|
out = self.fc4(out)
|
||||||
return self.classifier(out)
|
return self.classifier(out)
|
||||||
|
|
||||||
|
|
||||||
def test_nursery_pytorch_state_dict():
|
def test_nursery_pytorch_state_dict():
|
||||||
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
|
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
|
||||||
# reduce size of training set to make attack slightly better
|
# reduce size of training set to make attack slightly better
|
||||||
|
|
@ -49,22 +50,21 @@ def test_nursery_pytorch_state_dict():
|
||||||
x_test = x_test[:train_set_size]
|
x_test = x_test[:train_set_size]
|
||||||
y_test = y_test[:train_set_size]
|
y_test = y_test[:train_set_size]
|
||||||
|
|
||||||
|
inner_model = pytorch_model(4, 24)
|
||||||
|
|
||||||
model = pytorch_model(4, 24)
|
|
||||||
# model = torch.nn.DataParallel(model)
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = optim.Adam(model.parameters(), lr=0.01)
|
optimizer = optim.Adam(inner_model.parameters(), lr=0.01)
|
||||||
|
|
||||||
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
||||||
optimizer=optimizer, input_shape=(24,),
|
optimizer=optimizer, input_shape=(24,),
|
||||||
nb_classes=4)
|
nb_classes=4)
|
||||||
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False)
|
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100)
|
||||||
|
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||||
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
|
print('Base model accuracy: ', score)
|
||||||
|
assert (0 <= score <= 1)
|
||||||
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
|
model.load_best_state_dict_checkpoint()
|
||||||
art_model.load_best_state_dict_checkpoint()
|
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||||
|
print('Base model accuracy: ', score)
|
||||||
|
assert (0 <= score <= 1)
|
||||||
|
|
||||||
|
|
||||||
def test_nursery_pytorch_save_entire_model():
|
def test_nursery_pytorch_save_entire_model():
|
||||||
|
|
@ -76,8 +76,6 @@ def test_nursery_pytorch_save_entire_model():
|
||||||
x_test = x_test[:train_set_size]
|
x_test = x_test[:train_set_size]
|
||||||
y_test = y_test[:train_set_size]
|
y_test = y_test[:train_set_size]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model = pytorch_model(4, 24)
|
model = pytorch_model(4, 24)
|
||||||
# model = torch.nn.DataParallel(model)
|
# model = torch.nn.DataParallel(model)
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
@ -88,8 +86,10 @@ def test_nursery_pytorch_save_entire_model():
|
||||||
nb_classes=4)
|
nb_classes=4)
|
||||||
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True)
|
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True)
|
||||||
|
|
||||||
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
|
score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||||
|
print('Base model accuracy: ', score)
|
||||||
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
|
assert (0 <= score <= 1)
|
||||||
art_model.load_best_model_checkpoint()
|
art_model.load_best_model_checkpoint()
|
||||||
|
#score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||||
|
print('Base model accuracy: ', score)
|
||||||
|
assert (0 <= score <= 1)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue