mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-05 14:55:13 +02:00
update pytorch wrapper to use torch loaders
fix tests and dataset style
This commit is contained in:
parent
fdc6005fce
commit
c77e34e373
4 changed files with 178 additions and 113 deletions
|
|
@ -57,7 +57,7 @@ def test_nursery_pytorch_state_dict():
|
|||
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
||||
optimizer=optimizer, input_shape=(24,),
|
||||
nb_classes=4)
|
||||
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=1000)
|
||||
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10)
|
||||
model.load_latest_state_dict_checkpoint()
|
||||
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||
print('Base model accuracy: ', score)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue