mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-08 15:05:13 +02:00
fix
This commit is contained in:
parent
c954f53ad7
commit
21cba95a28
2 changed files with 43 additions and 8 deletions
|
|
@ -65,7 +65,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
return total_loss / total, float(correct) / total
|
||||
|
||||
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
|
||||
save_checkpoints: bool = True, **kwargs) -> None:
|
||||
save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None:
|
||||
"""
|
||||
Fit the classifier on the training set `(x, y)`.
|
||||
:param x: Training data.
|
||||
|
|
@ -74,6 +74,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
:param batch_size: Size of batches.
|
||||
:param nb_epochs: Number of epochs to use for training.
|
||||
:param save_checkpoints: Boolean, save checkpoints if True.
|
||||
:param save_entire_model: Boolean, save entire model if True, else save state dict.
|
||||
:param path: path for saving checkpoint.
|
||||
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
|
||||
supported for PyTorch and providing it takes no effect.
|
||||
"""
|
||||
|
|
@ -127,18 +129,22 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
train_acc = float(tot_correct) / total
|
||||
best_acc = max(val_acc, best_acc)
|
||||
if save_checkpoints:
|
||||
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc)
|
||||
if save_entire_model:
|
||||
self.save_checkpoint_model(is_best=best_acc <= val_acc)
|
||||
else:
|
||||
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc)
|
||||
|
||||
def save_checkpoint_state_dict(self, is_best: bool,
|
||||
def save_checkpoint_state_dict(self, is_best: bool, path=os.getcwd(),
|
||||
filename="latest.tar") -> None:
|
||||
"""
|
||||
Saves checkpoint as latest.tar or best.tar
|
||||
:param is_best: whether the model is the best achieved model
|
||||
:param path: path for saving checkpoint
|
||||
:param filename: checkpoint name
|
||||
:return: None
|
||||
"""
|
||||
# add path
|
||||
checkpoint = os.path.join(os.getcwd(), 'checkpoints')
|
||||
checkpoint = os.path.join(path, 'checkpoints')
|
||||
path = checkpoint
|
||||
os.makedirs(path, exist_ok=True)
|
||||
filepath = os.path.join(path, filename)
|
||||
|
|
@ -149,14 +155,15 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
if is_best:
|
||||
shutil.copyfile(filepath, os.path.join(path, 'model_best.tar'))
|
||||
|
||||
def save_checkpoint_model(self, is_best: bool, filename="latest.tar") -> None:
|
||||
def save_checkpoint_model(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
|
||||
"""
|
||||
Saves checkpoint as latest.tar or best.tar
|
||||
:param is_best: whether the model is the best achieved model
|
||||
:param path: path for saving checkpoint
|
||||
:param filename: checkpoint name
|
||||
:return: None
|
||||
"""
|
||||
checkpoint = os.path.join(os.getcwd(), 'checkpoints')
|
||||
checkpoint = os.path.join(path, 'checkpoints')
|
||||
path = checkpoint
|
||||
os.makedirs(path, exist_ok=True)
|
||||
filepath = os.path.join(path, filename)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class pytorch_model(nn.Module):
|
|||
out = self.fc4(out)
|
||||
return self.classifier(out)
|
||||
|
||||
def test_nursery_pytorch():
|
||||
def test_nursery_pytorch_state_dict():
|
||||
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
|
||||
# reduce size of training set to make attack slightly better
|
||||
train_set_size = 500
|
||||
|
|
@ -59,9 +59,37 @@ def test_nursery_pytorch():
|
|||
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
||||
optimizer=optimizer, input_shape=(24,),
|
||||
nb_classes=4)
|
||||
art_model.fit(PytorchData(x_train.astype(np.float32), y_train))
|
||||
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False)
|
||||
|
||||
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
|
||||
|
||||
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
|
||||
art_model.load_best_state_dict_checkpoint()
|
||||
|
||||
|
||||
def test_nursery_pytorch_save_entire_model():
|
||||
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
|
||||
# reduce size of training set to make attack slightly better
|
||||
train_set_size = 500
|
||||
x_train = x_train[:train_set_size]
|
||||
y_train = y_train[:train_set_size]
|
||||
x_test = x_test[:train_set_size]
|
||||
y_test = y_test[:train_set_size]
|
||||
|
||||
|
||||
|
||||
model = pytorch_model(4, 24)
|
||||
# model = torch.nn.DataParallel(model)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.01)
|
||||
|
||||
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
||||
optimizer=optimizer, input_shape=(24,),
|
||||
nb_classes=4)
|
||||
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True)
|
||||
|
||||
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
|
||||
|
||||
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
|
||||
art_model.load_best_model_checkpoint()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue