Mirror of https://github.com/IBM/ai-privacy-toolkit.git, synced 2026-04-24 20:36:21 +02:00
add validation set

This commit is contained in:
parent 65388da605
commit fdc6005fce

1 changed file with 23 additions and 11 deletions
@@ -8,7 +8,7 @@ import numpy as np
 from art.utils import check_and_transform_label_format, logger
 from apt.utils.datasets.datasets import PytorchData
 from apt.utils.models import Model, ModelOutputType
-from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
+from apt.utils.datasets import OUTPUT_DATA_ARRAY_TYPE

 from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
 import torch
@@ -61,13 +61,17 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):

         return total_loss / total, float(correct) / total

-    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10,
-            save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None:
+    def fit(self, x: np.ndarray, y: np.ndarray, x_validation: np.ndarray = None, y_validation: np.ndarray = None,
+            batch_size: int = 128, nb_epochs: int = 10, save_checkpoints: bool = True, save_entire_model=True,
+            path=os.getcwd(), **kwargs) -> None:
         """
         Fit the classifier on the training set `(x, y)`.
         :param x: Training data.
         :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels
                   of shape (nb_samples,).
+        :param x_validation: Validation data (optional).
+        :param y_validation: Target validation values (class labels) one-hot-encoded of shape (nb_samples, nb_classes)
+                             or index labels of shape (nb_samples,) (optional).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
         :param save_checkpoints: Boolean, save checkpoints if True.
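The widened signature keeps the original two-array call working while adding an optional held-out split. A minimal sketch of the new call, assuming the modified file is apt/utils/models/pytorch_model.py (the commit view does not name the file); the toy model, data, and shapes are illustrative, and only the constructor argument order is taken from this file:

```python
import numpy as np
import torch

# Import path is an assumption; the commit view does not name the file.
from apt.utils.models.pytorch_model import PyTorchClassifierWrapper

# Toy classifier matching the constructor call used later in this file:
# PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)
model = torch.nn.Linear(20, 2)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
wrapper = PyTorchClassifierWrapper(model, loss, (20,), 2, optimizer)

rng = np.random.default_rng(0)
x_train = rng.random((80, 20), dtype=np.float32)
y_train = rng.integers(0, 2, 80)
x_val = rng.random((20, 20), dtype=np.float32)
y_val = rng.integers(0, 2, 20)

# New in this commit: per-epoch evaluation runs on the held-out split.
wrapper.fit(x_train, y_train, x_validation=x_val, y_validation=y_val,
            batch_size=16, nb_epochs=2, save_checkpoints=False)
```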
@@ -83,6 +87,9 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
             raise ValueError("An optimizer is needed to train the model, but none for provided.")

         _y = check_and_transform_label_format(y, self.nb_classes)
+        if x_validation is None or y_validation is None:
+            x_validation = x
+            y_validation = y

         # Apply preprocessing
         x_preprocessed, y_preprocessed = self._apply_preprocessing(x, _y, fit=True)
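The guard above is what keeps old callers working: when either validation argument is omitted, the training arrays double as the validation set, so per-epoch metrics match the pre-commit behaviour. Continuing the sketch above:

```python
# No validation split supplied: per-epoch metrics are computed on the
# training data, exactly as before this commit.
wrapper.fit(x_train, y_train, batch_size=16, nb_epochs=2, save_checkpoints=False)
```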
@@ -121,7 +128,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
                 correct = self.get_step_correct(model_outputs[-1], o_batch)
                 tot_correct += correct
                 total += o_batch.shape[0]
-            val_loss, val_acc = self._eval(x, y, num_batch, batch_size)
+
+            val_loss, val_acc = self._eval(x_validation, y_validation, num_batch, batch_size)
             print(val_acc)
             best_acc = max(val_acc, best_acc)
             if save_checkpoints:
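One thing to watch in the new call: num_batch is computed outside this hunk, presumably from the training set, so _eval may be told to iterate more batches than a smaller validation split contains. Whether that matters depends on how _eval slices its inputs; a defensive variant (a hypothetical sketch inside the method body, not part of the commit) would size the loop to the validation arrays:

```python
# Hypothetical guard, not in the commit: derive the batch count from the
# validation split instead of reusing the training-set figure.
val_num_batch = int(np.ceil(len(x_validation) / float(batch_size)))
val_loss, val_acc = self._eval(x_validation, y_validation, val_num_batch, batch_size)
```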
@@ -270,13 +278,16 @@ class PyTorchClassifier(PyTorchModel):
         super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs)
         self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)

-    def fit(self, train_data: PytorchData, batch_size: int = 128, nb_epochs: int = 10,
+    def fit(self, train_data: PytorchData, validation_data: PytorchData = None, batch_size: int = 128,
+            nb_epochs: int = 10,
             save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None:
         """
         Fit the model using the training data.

         :param train_data: Training data.
-        :type train_data: `Dataset`
+        :type train_data: `PytorchData`
+        :param validation_data: Validation data (optional).
+        :type validation_data: `PytorchData`
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
         :param save_checkpoints: Boolean, save checkpoints if True.
@@ -285,10 +296,11 @@ class PyTorchClassifier(PyTorchModel):
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
                        supported for PyTorch and providing it takes no effect.
         """
-        self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), batch_size, nb_epochs,
-                            save_checkpoints, save_entire_model, path, **kwargs)
+        self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1),
+                            validation_data.get_samples(), validation_data.get_labels().reshape(-1, 1), batch_size,
+                            nb_epochs, save_checkpoints, save_entire_model, path, **kwargs)

-    def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
+    def predict(self, x: PytorchData, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
         """
         Perform predictions using the model for input `x`.

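At the PyTorchClassifier level the validation split is threaded straight through to the wrapper. Note that although validation_data defaults to None in the new signature, the call above invokes validation_data.get_samples() unconditionally, so omitting the argument would raise AttributeError rather than fall back to the training set; callers should pass a split explicitly. A usage sketch, assuming PytorchData wraps a (samples, labels) pair (consistent with its use here, but its constructor is not shown in this diff) and that clf is a PyTorchClassifier built beforehand:

```python
from apt.utils.datasets.datasets import PytorchData

train_data = PytorchData(x_train, y_train)   # assumed (samples, labels) constructor
validation_data = PytorchData(x_val, y_val)  # assumed (samples, labels) constructor

# Despite the None default, validation_data is effectively required here.
clf.fit(train_data, validation_data, batch_size=16, nb_epochs=2)
```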
@@ -298,11 +310,11 @@ class PyTorchClassifier(PyTorchModel):
         """
         return self._art_model.predict(x.get_samples(), **kwargs)

-    def score(self, test_data: Dataset, **kwargs):
+    def score(self, test_data: PytorchData, **kwargs):
         """
         Score the model using test data.
         :param test_data: Test data.
-        :type test_data: `Dataset`
+        :type test_data: `PytorchData`
         :return: the score as float (between 0 and 1)
         """
         y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes)
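With the annotations narrowed from Dataset to PytorchData, prediction and scoring take the same wrapper type as fit. Continuing the sketch above:

```python
# Both calls unwrap the dataset internally via get_samples()/get_labels().
preds = clf.predict(validation_data)  # returns OUTPUT_DATA_ARRAY_TYPE
acc = clf.score(validation_data)      # float between 0 and 1
```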