diff --git a/apt/utils/datasets/datasets.py b/apt/utils/datasets/datasets.py index ff7c296..8f80f15 100644 --- a/apt/utils/datasets/datasets.py +++ b/apt/utils/datasets/datasets.py @@ -221,21 +221,21 @@ class PytorchData(Dataset): def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE: """Return data samples as numpy array""" - return array2numpy(self._x) + return array2numpy(self, self._x) def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE: """Return labels as numpy array""" - return array2numpy(self._y) if self._y is not None else None + return array2numpy(self, self._y) if self._y is not None else None def get_sample_item(self, idx) -> Tensor: - return self.x[idx] + return self._x[idx] def get_item(self, idx) -> Tensor: - sample, label = self.x[idx], self.y[idx] + sample, label = self._x[idx], self._y[idx] return sample, label def __len__(self): - return len(self.x) + return len(self._x) class DatasetFactory: diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index c4fbb08..18d57b5 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -10,8 +10,9 @@ from sklearn.model_selection import train_test_split from sklearn.preprocessing import OneHotEncoder +from apt.utils.datasets.datasets import PytorchData from apt.utils.models import Model, ModelOutputType -from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE +from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE, Data from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier import torch @@ -37,41 +38,46 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): else: return int(torch.sum(torch.round(outputs, axis=-1) == targets).item()) - def _eval(self, x: np.ndarray, y: np.ndarray): + def _eval(self, x: np.ndarray, y: np.ndarray, nb_epochs, batch_size): self.model.eval() total_loss = 0 correct = 0 total = 0 - for m in range(len(x)): - inputs = torch.from_numpy(x[m]).to(self._device) - targets = torch.from_numpy(y[m]).to(self._device) - targets = targets.to(self.device) - outputs = self.model(inputs) - loss = self._loss(outputs, targets) - total_loss += (loss.item() * targets.size(0)) - total += targets.size(0) - correct += self.get_step_correct(outputs, targets) + y = check_and_transform_label_format(y, self.nb_classes) + x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True) + y_preprocessed = self.reduce_labels(y_preprocessed) + num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size))) + ind = np.arange(len(x_preprocessed)) + for epoch in range(nb_epochs): + random.shuffle(ind) + for m in range(num_batch): + inputs = torch.from_numpy(x_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device) + targets = torch.from_numpy(y_preprocessed[ind[m * batch_size: (m + 1) * batch_size]]).to(self._device) + targets = targets.to(self.device) + outputs = self.model(inputs) + loss = self._loss(outputs, targets) + total_loss += (loss.item() * targets.size(0)) + total += targets.size(0) + correct += self.get_step_correct(outputs, targets) return total_loss / total, float(correct) / total - def fit(self, X: np.ndarray, Y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, + def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, save_checkpoints: bool = True, **kwargs) -> None: """ - Fit the classifier on the training set `(x, y)`. - - :param X: Training data. - :param Y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of - shape (nb_samples,). - :param batch_size: Size of batches. - :param nb_epochs: Number of epochs to use for training. - :param save_checkpoints: Boolean, save checkpoints if True. - :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch - and providing it takes no effect. - """ + Fit the classifier on the training set `(x, y)`. + :param x: Training data. + :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels + of shape (nb_samples,). + :param batch_size: Size of batches. + :param nb_epochs: Number of epochs to use for training. + :param save_checkpoints: Boolean, save checkpoints if True. + :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently + supported for PyTorch and providing it takes no effect. + """ # Put the model in the training mode - x, X_test, y, y_test = train_test_split(X, Y, test_size=0.33, random_state=42) self._model.train() if self._optimizer is None: # pragma: no cover @@ -117,12 +123,11 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): correct = self.get_step_correct(model_outputs[-1], o_batch) tot_correct += correct total += o_batch.shape[0] - val_loss, val_acc = self._eval(X_test, y_test) + val_loss, val_acc = self._eval(x, y, num_batch, batch_size) train_acc = float(tot_correct) / total - - if save_checkpoints: - self.save_checkpoint_state_dict(is_best=best_acc <= val_acc) best_acc = max(val_acc, best_acc) + if save_checkpoints: + self.save_checkpoint_model(is_best=best_acc <= val_acc) def save_checkpoint_state_dict(self, is_best: bool, additional_states: Dict = None, filename="latest.tar") -> None: @@ -133,6 +138,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): :param filename: checkpoint name :return: None """ + # add path checkpoint = os.path.join(os.getcwd(), 'checkpoints') path = checkpoint os.makedirs(path, exist_ok=True) @@ -261,7 +267,7 @@ class PyTorchClassifier(PyTorchModel): super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs) self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer) - def fit(self, train_data: Dataset, **kwargs) -> None: + def fit(self, train_data: PytorchData, **kwargs) -> None: """ Fit the model using the training data. diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 16e7d1f..8841303 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -2,7 +2,8 @@ import numpy as np import torch from torch import nn, optim -from apt.utils.datasets import ArrayDataset +from apt.utils.datasets import ArrayDataset, Data, Dataset +from apt.utils.datasets.datasets import PytorchData from apt.utils.models import ModelOutputType from apt.utils.models.pytorch_model import PyTorchClassifier from art.utils import load_nursery @@ -54,14 +55,11 @@ def test_nursery_pytorch(): optimizer = optim.Adam(model.parameters(), lr=0.01) art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, - optimizer=optimizer, input_shape=(24,), - nb_classes=4) - art_model.fit(ArrayDataset(x_train.astype(np.float32), y_train)) + optimizer=optimizer, input_shape=(24,), + nb_classes=4) + art_model.fit(PytorchData(x_train.astype(np.float32), y_train)) pred = np.array([np.argmax(arr) for arr in art_model.predict(ArrayDataset(x_test.astype(np.float32)))]) print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test)) - art_model.load_best_state_dict_checkpoint() - - - + art_model.load_best_model_checkpoint()