mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-25 04:46:21 +02:00
fix
This commit is contained in:
parent
59d8b16bb4
commit
8459d6961f
1 changed files with 5 additions and 6 deletions
|
|
@ -15,7 +15,6 @@ from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
|
|||
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
|
||||
import torch
|
||||
|
||||
|
||||
class PyTorchModel(Model):
|
||||
"""
|
||||
Wrapper class for pytorch models.
|
||||
|
|
@ -31,11 +30,10 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
"""get number of correctly classified labels"""
|
||||
if len(outputs) != len(targets):
|
||||
raise ValueError("outputs and targets should be the same length.")
|
||||
counter = 0
|
||||
for i, o in enumerate(outputs):
|
||||
if o == targets[i]:
|
||||
counter += 1
|
||||
return counter
|
||||
if self.nb_classes > 1:
|
||||
return int(torch.sum(torch.argmax(outputs, axis=-1) == targets).item())
|
||||
else:
|
||||
return int(torch.sum(torch.round(outputs, axis=-1) == targets).item())
|
||||
|
||||
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
|
||||
save_checkpoints: bool = True, **kwargs) -> None:
|
||||
|
|
@ -115,6 +113,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
"""
|
||||
checkpoint = os.path.join(os.getcwd(), 'checkpoints')
|
||||
path = checkpoint
|
||||
os.makedirs(path, exist_ok=True)
|
||||
filepath = os.path.join(path, filename)
|
||||
state = additional_states if additional_states else dict()
|
||||
state['state_dict'] = self.model.module.state_dict() \
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue