This commit is contained in:
olasaadi 2022-07-20 18:29:48 +03:00
parent 3bf26b67d2
commit 6f69f5557b
2 changed files with 7 additions and 6 deletions

View file

@ -78,6 +78,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
"""
# Put the model in the training mode
self._model.train()
print(nb_epochs)
if self._optimizer is None: # pragma: no cover
raise ValueError("An optimizer is needed to train the model, but none for provided.")
@ -122,7 +123,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
tot_correct += correct
total += o_batch.shape[0]
val_loss, val_acc = self._eval(x, y, num_batch, batch_size)
# print acc TODO
print(val_acc)
best_acc = max(val_acc, best_acc)
if save_checkpoints:
if save_entire_model:
@ -222,7 +223,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
raise FileNotFoundError(msg)
else:
self._model = torch.load(filepath)
self._model._model = torch.load(filepath)
def load_latest_model_checkpoint(self):
"""

View file

@ -57,7 +57,7 @@ def test_nursery_pytorch_state_dict():
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=100)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=1000)
model.load_latest_state_dict_checkpoint()
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
@ -65,7 +65,7 @@ def test_nursery_pytorch_state_dict():
# python pytorch numpy
model.load_best_state_dict_checkpoint()
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
print('best model accuracy: ', score)
assert (0 <= score <= 1)
@ -87,12 +87,12 @@ def test_nursery_pytorch_save_entire_model():
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True)
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10)
score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)
art_model.load_best_model_checkpoint()
score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
print('best model accuracy: ', score)
assert (0 <= score <= 1)