This commit is contained in:
olasaadi 2022-07-04 12:58:35 +03:00
parent af7d615628
commit 07e64b1f86

View file

@ -96,7 +96,6 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
for epoch in range(nb_epochs):
tot_correct = 0
total = 0
val_acc = 0
best_acc = 0
# Shuffle the examples
random.shuffle(ind)
@ -113,7 +112,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
model_outputs = self._model(i_batch)
# Form the loss function
loss = self._loss(model_outputs[-1], o_batch) # lgtm [py/call-to-non-callable]
loss = self._loss(model_outputs[-1], o_batch)
loss.backward()
@ -122,7 +121,6 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
tot_correct += correct
total += o_batch.shape[0]
val_loss, val_acc = self._eval(x, y, num_batch, batch_size)
train_acc = float(tot_correct) / total
best_acc = max(val_acc, best_acc)
if save_checkpoints:
if save_entire_model: