This commit is contained in:
olasaadi 2022-05-23 12:49:38 +03:00
parent 019f49861d
commit 59d8b16bb4
2 changed files with 66 additions and 1 deletions

View file

@ -94,7 +94,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
loss.backward()
self._optimizer.step()
correct = self.get_step_correct(model_outputs, o_batch)
correct = self.get_step_correct(model_outputs[-1], o_batch)
tot_correct += correct
total += o_batch.shape[0]
train_acc = float(tot_correct) / total