mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-02 14:45:13 +02:00
fix
This commit is contained in:
parent
4973fbebc6
commit
3bf26b67d2
2 changed files with 5 additions and 1 deletions
|
|
@ -122,6 +122,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
||||||
tot_correct += correct
|
tot_correct += correct
|
||||||
total += o_batch.shape[0]
|
total += o_batch.shape[0]
|
||||||
val_loss, val_acc = self._eval(x, y, num_batch, batch_size)
|
val_loss, val_acc = self._eval(x, y, num_batch, batch_size)
|
||||||
|
# print acc TODO
|
||||||
best_acc = max(val_acc, best_acc)
|
best_acc = max(val_acc, best_acc)
|
||||||
if save_checkpoints:
|
if save_checkpoints:
|
||||||
if save_entire_model:
|
if save_entire_model:
|
||||||
|
|
|
||||||
|
|
@ -58,9 +58,11 @@ def test_nursery_pytorch_state_dict():
|
||||||
optimizer=optimizer, input_shape=(24,),
|
optimizer=optimizer, input_shape=(24,),
|
||||||
nb_classes=4)
|
nb_classes=4)
|
||||||
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100)
|
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100)
|
||||||
|
model.load_latest_state_dict_checkpoint()
|
||||||
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||||
print('Base model accuracy: ', score)
|
print('Base model accuracy: ', score)
|
||||||
assert (0 <= score <= 1)
|
assert (0 <= score <= 1)
|
||||||
|
# python pytorch numpy
|
||||||
model.load_best_state_dict_checkpoint()
|
model.load_best_state_dict_checkpoint()
|
||||||
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||||
print('Base model accuracy: ', score)
|
print('Base model accuracy: ', score)
|
||||||
|
|
@ -68,6 +70,7 @@ def test_nursery_pytorch_state_dict():
|
||||||
|
|
||||||
|
|
||||||
def test_nursery_pytorch_save_entire_model():
|
def test_nursery_pytorch_save_entire_model():
|
||||||
|
|
||||||
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
|
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
|
||||||
# reduce size of training set to make attack slightly better
|
# reduce size of training set to make attack slightly better
|
||||||
train_set_size = 500
|
train_set_size = 500
|
||||||
|
|
@ -90,6 +93,6 @@ def test_nursery_pytorch_save_entire_model():
|
||||||
print('Base model accuracy: ', score)
|
print('Base model accuracy: ', score)
|
||||||
assert (0 <= score <= 1)
|
assert (0 <= score <= 1)
|
||||||
art_model.load_best_model_checkpoint()
|
art_model.load_best_model_checkpoint()
|
||||||
#score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||||
print('Base model accuracy: ', score)
|
print('Base model accuracy: ', score)
|
||||||
assert (0 <= score <= 1)
|
assert (0 <= score <= 1)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue