This commit is contained in:
olasaadi 2022-05-23 13:35:09 +03:00
parent 59d8b16bb4
commit 8459d6961f

View file

@ -15,7 +15,6 @@ from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
import torch
class PyTorchModel(Model):
"""
Wrapper class for pytorch models.
@ -31,11 +30,10 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
"""get number of correctly classified labels"""
if len(outputs) != len(targets):
raise ValueError("outputs and targets should be the same length.")
counter = 0
for i, o in enumerate(outputs):
if o == targets[i]:
counter += 1
return counter
if self.nb_classes > 1:
return int(torch.sum(torch.argmax(outputs, axis=-1) == targets).item())
else:
return int(torch.sum(torch.round(outputs, axis=-1) == targets).item())
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
save_checkpoints: bool = True, **kwargs) -> None:
@ -115,6 +113,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
"""
checkpoint = os.path.join(os.getcwd(), 'checkpoints')
path = checkpoint
os.makedirs(path, exist_ok=True)
filepath = os.path.join(path, filename)
state = additional_states if additional_states else dict()
state['state_dict'] = self.model.module.state_dict() \