mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-24 20:36:21 +02:00
fix
This commit is contained in:
parent
af7d615628
commit
07e64b1f86
1 changed files with 1 additions and 3 deletions
|
|
@ -96,7 +96,6 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
for epoch in range(nb_epochs):
|
||||
tot_correct = 0
|
||||
total = 0
|
||||
val_acc = 0
|
||||
best_acc = 0
|
||||
# Shuffle the examples
|
||||
random.shuffle(ind)
|
||||
|
|
@ -113,7 +112,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
model_outputs = self._model(i_batch)
|
||||
|
||||
# Form the loss function
|
||||
loss = self._loss(model_outputs[-1], o_batch) # lgtm [py/call-to-non-callable]
|
||||
loss = self._loss(model_outputs[-1], o_batch)
|
||||
|
||||
loss.backward()
|
||||
|
||||
|
|
@ -122,7 +121,6 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
tot_correct += correct
|
||||
total += o_batch.shape[0]
|
||||
val_loss, val_acc = self._eval(x, y, num_batch, batch_size)
|
||||
train_acc = float(tot_correct) / total
|
||||
best_acc = max(val_acc, best_acc)
|
||||
if save_checkpoints:
|
||||
if save_entire_model:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue