update pytorch wrapper to use torch loaders

fix tests
and dataset style
This commit is contained in:
Ron Shmelkin 2022-07-24 14:31:47 +03:00
parent fdc6005fce
commit c77e34e373
No known key found for this signature in database
GPG key ID: A4289A6607B5C294
4 changed files with 178 additions and 113 deletions

View file

@ -57,7 +57,7 @@ def test_nursery_pytorch_state_dict():
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=1000)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10)
model.load_latest_state_dict_checkpoint()
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)