This commit is contained in:
olasaadi 2022-06-02 15:25:07 +03:00
parent a3fb68fb56
commit 302d0c4b8c
3 changed files with 46 additions and 42 deletions

View file

@ -221,21 +221,21 @@ class PytorchData(Dataset):
# Accessor for the stored samples: converts `self._x` with `array2numpy` and returns the result.
# NOTE(review): two return statements follow; they appear to be the before/after lines of a
# rendered diff. Read as plain code, only the first return executes and the second is unreachable.
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""Return data samples as numpy array"""
# Variant 1: passes only the array to `array2numpy`.
return array2numpy(self._x)
# Variant 2: additionally passes `self` as the first argument.
# NOTE(review): which variant is correct depends on the signature of `array2numpy`, which is
# not visible here -- confirm against its definition before keeping either line.
return array2numpy(self, self._x)
# Accessor for the stored labels: returns `None` when `self._y` is `None`, otherwise the result
# of converting `self._y` with `array2numpy`.
# NOTE(review): two return statements follow; they appear to be the before/after lines of a
# rendered diff. Read as plain code, only the first return executes and the second is unreachable.
def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""Return labels as numpy array"""
# Variant 1: passes only the array to `array2numpy`.
return array2numpy(self._y) if self._y is not None else None
# Variant 2: additionally passes `self` as the first argument.
# NOTE(review): which variant is correct depends on the signature of `array2numpy`, which is
# not visible here -- confirm against its definition before keeping either line.
return array2numpy(self, self._y) if self._y is not None else None
def get_sample_item(self, idx) -> Tensor:
return self.x[idx]
return self._x[idx]
def get_item(self, idx) -> Tensor:
sample, label = self.x[idx], self.y[idx]
sample, label = self._x[idx], self._y[idx]
return sample, label
def __len__(self):
return len(self.x)
return len(self._x)
class DatasetFactory:

View file

@ -10,8 +10,9 @@ from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import Model, ModelOutputType
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE, Data
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
import torch
@ -37,41 +38,46 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
else:
return int(torch.sum(torch.round(outputs, axis=-1) == targets).item())
# Evaluation helper: switches the model to eval mode, runs it over (x, y), accumulates the
# loss weighted by the number of targets plus the number of correct predictions, and returns
# the pair (total_loss / total, correct / total).
# NOTE(review): this span interleaves two versions of the method (two `def` lines and two loop
# bodies), apparently the before/after lines of a rendered diff; it is not runnable as shown.
# Version 1: signature without batching arguments.
def _eval(self, x: np.ndarray, y: np.ndarray):
# Version 2: adds `nb_epochs` and `batch_size`.
# NOTE(review): the visible call site passes `num_batch` in the `nb_epochs` position
# (`self._eval(x, y, num_batch, batch_size)`) -- confirm that is intended.
def _eval(self, x: np.ndarray, y: np.ndarray, nb_epochs, batch_size):
self.model.eval()
# Running sums used for the final averages.
total_loss = 0
correct = 0
total = 0
# Version 1 loop: feeds the model one element of x / y at a time.
for m in range(len(x)):
inputs = torch.from_numpy(x[m]).to(self._device)
targets = torch.from_numpy(y[m]).to(self._device)
# NOTE(review): targets were already moved with `self._device` on the previous line; this
# second move uses `self.device` -- presumably the same device, verify.
targets = targets.to(self.device)
outputs = self.model(inputs)
loss = self._loss(outputs, targets)
# Weight the loss by the number of targets so the final division yields a per-target mean.
total_loss += (loss.item() * targets.size(0))
total += targets.size(0)
correct += self.get_step_correct(outputs, targets)
# Version 2 body: transform labels, preprocess, then iterate in shuffled mini-batches.
y = check_and_transform_label_format(y, self.nb_classes)
# NOTE(review): preprocessing is applied with `fit=True` inside an evaluation routine --
# confirm this is intended rather than `fit=False`.
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
y_preprocessed = self.reduce_labels(y_preprocessed)
# Number of mini-batches, rounding up so a final partial batch is included.
num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
ind = np.arange(len(x_preprocessed))
# NOTE(review): the whole data set is re-evaluated `nb_epochs` times and all passes are
# accumulated into the same sums.
for epoch in range(nb_epochs):
# NOTE(review): `random` is not among the import lines visible in this view -- confirm it
# is imported at the top of the module.
random.shuffle(ind)
for m in range(num_batch):
# Slice the shuffled index array to select the m-th mini-batch.
inputs = torch.from_numpy(x_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device)
targets = torch.from_numpy(y_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device)
targets = targets.to(self.device)
outputs = self.model(inputs)
loss = self._loss(outputs, targets)
total_loss += (loss.item() * targets.size(0))
total += targets.size(0)
correct += self.get_step_correct(outputs, targets)
# NOTE(review): divides by `total`; an empty input leaves `total` at 0 and raises
# ZeroDivisionError. No `torch.no_grad()` context is used around the forward passes.
return total_loss / total, float(correct) / total
def fit(self, X: np.ndarray, Y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
save_checkpoints: bool = True, **kwargs) -> None:
"""
Fit the classifier on the training set `(x, y)`.
:param X: Training data.
:param Y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
shape (nb_samples,).
:param batch_size: Size of batches.
:param nb_epochs: Number of epochs to use for training.
:param save_checkpoints: Boolean, save checkpoints if True.
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
and providing it takes no effect.
"""
Fit the classifier on the training set `(x, y)`.
:param x: Training data.
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels
of shape (nb_samples,).
:param batch_size: Size of batches.
:param nb_epochs: Number of epochs to use for training.
:param save_checkpoints: Boolean, save checkpoints if True.
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
supported for PyTorch and providing it takes no effect.
"""
# Put the model in the training mode
x, X_test, y, y_test = train_test_split(X, Y, test_size=0.33, random_state=42)
self._model.train()
if self._optimizer is None: # pragma: no cover
@ -117,12 +123,11 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
correct = self.get_step_correct(model_outputs[-1], o_batch)
tot_correct += correct
total += o_batch.shape[0]
val_loss, val_acc = self._eval(X_test, y_test)
val_loss, val_acc = self._eval(x, y, num_batch, batch_size)
train_acc = float(tot_correct) / total
if save_checkpoints:
self.save_checkpoint_state_dict(is_best=best_acc <= val_acc)
best_acc = max(val_acc, best_acc)
if save_checkpoints:
self.save_checkpoint_model(is_best=best_acc <= val_acc)
def save_checkpoint_state_dict(self, is_best: bool, additional_states: Dict = None,
filename="latest.tar") -> None:
@ -133,6 +138,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
:param filename: checkpoint name
:return: None
"""
# add path
checkpoint = os.path.join(os.getcwd(), 'checkpoints')
path = checkpoint
os.makedirs(path, exist_ok=True)
@ -261,7 +267,7 @@ class PyTorchClassifier(PyTorchModel):
super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs)
self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)
def fit(self, train_data: Dataset, **kwargs) -> None:
def fit(self, train_data: PytorchData, **kwargs) -> None:
"""
Fit the model using the training data.

View file

@ -2,7 +2,8 @@ import numpy as np
import torch
from torch import nn, optim
from apt.utils.datasets import ArrayDataset
from apt.utils.datasets import ArrayDataset, Data, Dataset
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import ModelOutputType
from apt.utils.models.pytorch_model import PyTorchClassifier
from art.utils import load_nursery
@ -54,14 +55,11 @@ def test_nursery_pytorch():
optimizer = optim.Adam(model.parameters(), lr=0.01)
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
art_model.fit(ArrayDataset(x_train.astype(np.float32), y_train))
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
art_model.fit(PytorchData(x_train.astype(np.float32), y_train))
pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))])
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
art_model.load_best_state_dict_checkpoint()
art_model.load_best_model_checkpoint()