mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-05 01:32:38 +02:00
fix
This commit is contained in:
parent
019f49861d
commit
59d8b16bb4
2 changed files with 66 additions and 1 deletions
|
|
@ -94,7 +94,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
loss.backward()
|
||||
|
||||
self._optimizer.step()
|
||||
correct = self.get_step_correct(model_outputs, o_batch)
|
||||
correct = self.get_step_correct(model_outputs[-1], o_batch)
|
||||
tot_correct += correct
|
||||
total += o_batch.shape[0]
|
||||
train_acc = float(tot_correct) / total
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue