This commit is contained in:
olasaadi 2022-07-20 18:36:58 +03:00
parent 6f69f5557b
commit c2c7a01078

View file

@ -78,7 +78,6 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
"""
# Put the model in the training mode
self._model.train()
print(nb_epochs)
if self._optimizer is None: # pragma: no cover
raise ValueError("An optimizer is needed to train the model, but none for provided.")
@ -224,6 +223,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
else:
self._model._model = torch.load(filepath)
self.model.eval()
def load_latest_model_checkpoint(self):
"""