From fdc6005fce105b120fd84a78ec3c5c52b00fb168 Mon Sep 17 00:00:00 2001 From: olasaadi Date: Fri, 22 Jul 2022 01:01:45 +0300 Subject: [PATCH] add validation set --- apt/utils/models/pytorch_model.py | 34 +++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 729150f..f245ad4 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -8,7 +8,7 @@ import numpy as np from art.utils import check_and_transform_label_format, logger from apt.utils.datasets.datasets import PytorchData from apt.utils.models import Model, ModelOutputType -from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE +from apt.utils.datasets import OUTPUT_DATA_ARRAY_TYPE from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier import torch @@ -61,13 +61,17 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): return total_loss / total, float(correct) / total - def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, - save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None: + def fit(self, x: np.ndarray, y: np.ndarray, x_validation: np.ndarray = None, y_validation: np.ndarray = None, + batch_size: int = 128, nb_epochs: int = 10, save_checkpoints: bool = True, save_entire_model=True, + path=os.getcwd(), **kwargs) -> None: """ Fit the classifier on the training set `(x, y)`. :param x: Training data. :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of shape (nb_samples,). + :param x_validation: Validation data (optional). + :param y_validation: Target validation values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels + of shape (nb_samples,) (optional). :param batch_size: Size of batches. :param nb_epochs: Number of epochs to use for training. 
:param save_checkpoints: Boolean, save checkpoints if True. @@ -83,6 +87,9 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): raise ValueError("An optimizer is needed to train the model, but none for provided.") _y = check_and_transform_label_format(y, self.nb_classes) + if x_validation is None or y_validation is None: + x_validation = x + y_validation = y # Apply preprocessing x_preprocessed, y_preprocessed = self._apply_preprocessing(x, _y, fit=True) @@ -121,7 +128,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): correct = self.get_step_correct(model_outputs[-1], o_batch) tot_correct += correct total += o_batch.shape[0] - val_loss, val_acc = self._eval(x, y, num_batch, batch_size) + + val_loss, val_acc = self._eval(x_validation, y_validation, num_batch, batch_size) print(val_acc) best_acc = max(val_acc, best_acc) if save_checkpoints: @@ -270,13 +278,16 @@ class PyTorchClassifier(PyTorchModel): super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs) self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer) - def fit(self, train_data: PytorchData, batch_size: int = 128, nb_epochs: int = 10, + def fit(self, train_data: PytorchData, validation_data: PytorchData = None, batch_size: int = 128, + nb_epochs: int = 10, save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None: """ Fit the model using the training data. :param train_data: Training data. - :type train_data: `Dataset` + :type train_data: `PytorchData` + :param validation_data: Validation data. + :type validation_data: `PytorchData` :param batch_size: Size of batches. :param nb_epochs: Number of epochs to use for training. :param save_checkpoints: Boolean, save checkpoints if True. @@ -285,10 +296,11 @@ class PyTorchClassifier(PyTorchModel): :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch and providing it takes no effect. 
""" - self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), batch_size, nb_epochs, - save_checkpoints, save_entire_model, path, **kwargs) + self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), + validation_data.get_samples(), validation_data.get_labels().reshape(-1, 1), batch_size, + nb_epochs, save_checkpoints, save_entire_model, path, **kwargs) - def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE: + def predict(self, x: PytorchData, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE: """ Perform predictions using the model for input `x`. @@ -298,11 +310,11 @@ class PyTorchClassifier(PyTorchModel): """ return self._art_model.predict(x.get_samples(), **kwargs) - def score(self, test_data: Dataset, **kwargs): + def score(self, test_data: PytorchData, **kwargs): """ Score the model using test data. :param test_data: Test data. - :type test_data: `Dataset` + :type test_data: `PytorchData` :return: the score as float (between 0 and 1) """ y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes)