diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 13627cb..a7f15e2 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -78,7 +78,6 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): """ # Put the model in the training mode self._model.train() - print(nb_epochs) if self._optimizer is None: # pragma: no cover raise ValueError("An optimizer is needed to train the model, but none for provided.") @@ -224,6 +223,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): else: self._model._model = torch.load(filepath) + self.model.eval() def load_latest_model_checkpoint(self): """