mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-04 17:22:37 +02:00
Increase version to 0.2.0 (#74)
* Remove tensorflow dependency if not using keras model * Remove xgboost dependency if not using xgboost model * Documentation updates Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
782edabd58
commit
8a9ef80146
25 changed files with 306 additions and 152 deletions
|
|
@ -31,7 +31,9 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
"""
|
||||
|
||||
def get_step_correct(self, outputs, targets) -> int:
|
||||
"""get number of correctly classified labels"""
|
||||
"""
|
||||
Get number of correctly classified labels.
|
||||
"""
|
||||
if len(outputs) != len(targets):
|
||||
raise ValueError("outputs and targets should be the same length.")
|
||||
if self.nb_classes > 1:
|
||||
|
|
@ -40,7 +42,9 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
return int(torch.sum(torch.round(outputs, axis=-1) == targets).item())
|
||||
|
||||
def _eval(self, loader: DataLoader):
|
||||
"""inner function for model evaluation"""
|
||||
"""
|
||||
Inner function for model evaluation.
|
||||
"""
|
||||
self.model.eval()
|
||||
|
||||
total_loss = 0
|
||||
|
|
@ -74,19 +78,20 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
) -> None:
|
||||
"""
|
||||
Fit the classifier on the training set `(x, y)`.
|
||||
|
||||
:param x: Training data.
|
||||
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels
|
||||
of shape (nb_samples,).
|
||||
of shape (nb_samples,).
|
||||
:param x_validation: Validation data (optional).
|
||||
:param y_validation: Target validation values (class labels) one-hot-encoded of shape
|
||||
(nb_samples, nb_classes) or index labels of shape (nb_samples,) (optional).
|
||||
(nb_samples, nb_classes) or index labels of shape (nb_samples,) (optional).
|
||||
:param batch_size: Size of batches.
|
||||
:param nb_epochs: Number of epochs to use for training.
|
||||
:param save_checkpoints: Boolean, save checkpoints if True.
|
||||
:param save_entire_model: Boolean, save entire model if True, else save state dict.
|
||||
:param path: path for saving checkpoint.
|
||||
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
|
||||
supported for PyTorch and providing it takes no effect.
|
||||
supported for PyTorch and providing it takes no effect.
|
||||
"""
|
||||
# Put the model in the training mode
|
||||
self._model.train()
|
||||
|
|
@ -153,7 +158,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
|
||||
def save_checkpoint_state_dict(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
|
||||
"""
|
||||
Saves checkpoint as latest.tar or best.tar
|
||||
Saves checkpoint as latest.tar or best.tar.
|
||||
|
||||
:param is_best: whether the model is the best achieved model
|
||||
:param path: path for saving checkpoint
|
||||
:param filename: checkpoint name
|
||||
|
|
@ -176,7 +182,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
|
||||
def save_checkpoint_model(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
|
||||
"""
|
||||
Saves checkpoint as latest.tar or best.tar
|
||||
Saves checkpoint as latest.tar or best.tar.
|
||||
|
||||
:param is_best: whether the model is the best achieved model
|
||||
:param path: path for saving checkpoint
|
||||
:param filename: checkpoint name
|
||||
|
|
@ -194,7 +201,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
|
||||
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
|
||||
"""
|
||||
Load model only based on the check point path
|
||||
Load model only based on the check point path.
|
||||
|
||||
:param model_name: check point filename
|
||||
:param path: checkpoint path (default current work dir)
|
||||
:return: loaded model
|
||||
|
|
@ -219,21 +227,24 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
|
||||
def load_latest_state_dict_checkpoint(self):
|
||||
"""
|
||||
Load model state dict only based on the check point path (latest.tar)
|
||||
Load model state dict only based on the check point path (latest.tar).
|
||||
|
||||
:return: loaded model
|
||||
"""
|
||||
self.load_checkpoint_state_dict_by_path("latest.tar")
|
||||
|
||||
def load_best_state_dict_checkpoint(self):
|
||||
"""
|
||||
Load model state dict only based on the check point path (model_best.tar)
|
||||
Load model state dict only based on the check point path (model_best.tar).
|
||||
|
||||
:return: loaded model
|
||||
"""
|
||||
self.load_checkpoint_state_dict_by_path("model_best.tar")
|
||||
|
||||
def load_checkpoint_model_by_path(self, model_name: str, path: str = None):
|
||||
"""
|
||||
Load model only based on the check point path
|
||||
Load model only based on the check point path.
|
||||
|
||||
:param model_name: check point filename
|
||||
:param path: checkpoint path (default current work dir)
|
||||
:return: loaded model
|
||||
|
|
@ -254,14 +265,16 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
|
||||
def load_latest_model_checkpoint(self):
|
||||
"""
|
||||
Load entire model only based on the check point path (latest.tar)
|
||||
Load entire model only based on the check point path (latest.tar).
|
||||
|
||||
:return: loaded model
|
||||
"""
|
||||
self.load_checkpoint_model_by_path("latest.tar")
|
||||
|
||||
def load_best_model_checkpoint(self):
|
||||
"""
|
||||
Load entire model only based on the check point path (model_best.tar)
|
||||
Load entire model only based on the check point path (model_best.tar).
|
||||
|
||||
:return: loaded model
|
||||
"""
|
||||
self.load_checkpoint_model_by_path("model_best.tar")
|
||||
|
|
@ -288,11 +301,11 @@ class PyTorchClassifier(PyTorchModel):
|
|||
Initialization specifically for the PyTorch-based implementation.
|
||||
|
||||
:param model: PyTorch model. The output of the model can be logits, probabilities or anything else. Logits
|
||||
output should be preferred where possible to ensure attack efficiency.
|
||||
output should be preferred where possible to ensure attack efficiency.
|
||||
:param output_type: The type of output the model yields (vector/label only for classifiers,
|
||||
value for regressors)
|
||||
:param loss: The loss function for which to compute gradients for training. The target label must be raw
|
||||
categorical, i.e. not converted to one-hot encoding.
|
||||
categorical, i.e. not converted to one-hot encoding.
|
||||
:param input_shape: The shape of one input instance.
|
||||
:param optimizer: The optimizer used to train the classifier.
|
||||
:param black_box_access: Boolean describing the type of deployment of the model (when in production).
|
||||
|
|
@ -311,7 +324,7 @@ class PyTorchClassifier(PyTorchModel):
|
|||
@property
|
||||
def loss(self):
|
||||
"""
|
||||
The pytorch model's loss function
|
||||
The pytorch model's loss function.
|
||||
|
||||
:return: The pytorch model's loss function
|
||||
"""
|
||||
|
|
@ -320,7 +333,7 @@ class PyTorchClassifier(PyTorchModel):
|
|||
@property
|
||||
def optimizer(self):
|
||||
"""
|
||||
The pytorch model's optimizer
|
||||
The pytorch model's optimizer.
|
||||
|
||||
:return: The pytorch model's optimizer
|
||||
"""
|
||||
|
|
@ -350,7 +363,7 @@ class PyTorchClassifier(PyTorchModel):
|
|||
:param save_entire_model: Boolean, save entire model if True, else save state dict.
|
||||
:param path: path for saving checkpoint.
|
||||
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
|
||||
supported for PyTorch and providing it takes no effect.
|
||||
supported for PyTorch and providing it takes no effect.
|
||||
"""
|
||||
if validation_data is None:
|
||||
self._art_model.fit(
|
||||
|
|
@ -390,6 +403,7 @@ class PyTorchClassifier(PyTorchModel):
|
|||
def score(self, test_data: PytorchData, **kwargs):
|
||||
"""
|
||||
Score the model using test data.
|
||||
|
||||
:param test_data: Test data.
|
||||
:type test_data: `PytorchData`
|
||||
:return: the score as float (between 0 and 1)
|
||||
|
|
@ -400,7 +414,8 @@ class PyTorchClassifier(PyTorchModel):
|
|||
|
||||
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
|
||||
"""
|
||||
Load model only based on the check point path
|
||||
Load model only based on the check point path.
|
||||
|
||||
:param model_name: check point filename
|
||||
:param path: checkpoint path (default current work dir)
|
||||
:return: loaded model
|
||||
|
|
@ -409,21 +424,24 @@ class PyTorchClassifier(PyTorchModel):
|
|||
|
||||
def load_latest_state_dict_checkpoint(self):
|
||||
"""
|
||||
Load model state dict only based on the check point path (latest.tar)
|
||||
Load model state dict only based on the check point path (latest.tar).
|
||||
|
||||
:return: loaded model
|
||||
"""
|
||||
self._art_model.load_latest_state_dict_checkpoint()
|
||||
|
||||
def load_best_state_dict_checkpoint(self):
|
||||
"""
|
||||
Load model state dict only based on the check point path (model_best.tar)
|
||||
Load model state dict only based on the check point path (model_best.tar).
|
||||
|
||||
:return: loaded model
|
||||
"""
|
||||
self._art_model.load_best_state_dict_checkpoint()
|
||||
|
||||
def load_checkpoint_model_by_path(self, model_name: str, path: str = None):
|
||||
"""
|
||||
Load model only based on the check point path
|
||||
Load model only based on the check point path.
|
||||
|
||||
:param model_name: check point filename
|
||||
:param path: checkpoint path (default current work dir)
|
||||
:return: loaded model
|
||||
|
|
@ -432,14 +450,16 @@ class PyTorchClassifier(PyTorchModel):
|
|||
|
||||
def load_latest_model_checkpoint(self):
|
||||
"""
|
||||
Load entire model only based on the check point path (latest.tar)
|
||||
Load entire model only based on the check point path (latest.tar).
|
||||
|
||||
:return: loaded model
|
||||
"""
|
||||
self._art_model.load_latest_model_checkpoint()
|
||||
|
||||
def load_best_model_checkpoint(self):
|
||||
"""
|
||||
Load entire model only based on the check point path (model_best.tar)
|
||||
Load entire model only based on the check point path (model_best.tar).
|
||||
|
||||
:return: loaded model
|
||||
"""
|
||||
self._art_model.load_best_model_checkpoint()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue