diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 3805a86..1fd2a02 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -96,7 +96,6 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): for epoch in range(nb_epochs): tot_correct = 0 total = 0 - val_acc = 0 best_acc = 0 # Shuffle the examples random.shuffle(ind) @@ -113,7 +112,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): model_outputs = self._model(i_batch) # Form the loss function - loss = self._loss(model_outputs[-1], o_batch) # lgtm [py/call-to-non-callable] + loss = self._loss(model_outputs[-1], o_batch) loss.backward() @@ -122,7 +121,6 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): tot_correct += correct total += o_batch.shape[0] val_loss, val_acc = self._eval(x, y, num_batch, batch_size) - train_acc = float(tot_correct) / total best_acc = max(val_acc, best_acc) if save_checkpoints: if save_entire_model: