Increase version to 0.2.0 (#74)

* Remove tensorflow dependency if not using keras model
* Remove xgboost dependency if not using xgboost model
* Documentation updates

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailgold 2023-05-08 12:50:55 +03:00 committed by GitHub
parent 782edabd58
commit 8a9ef80146
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 306 additions and 152 deletions

View file

@ -31,7 +31,9 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
"""
def get_step_correct(self, outputs, targets) -> int:
"""get number of correctly classified labels"""
"""
Get number of correctly classified labels.
"""
if len(outputs) != len(targets):
raise ValueError("outputs and targets should be the same length.")
if self.nb_classes > 1:
@ -40,7 +42,9 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
return int(torch.sum(torch.round(outputs, axis=-1) == targets).item())
def _eval(self, loader: DataLoader):
"""inner function for model evaluation"""
"""
Inner function for model evaluation.
"""
self.model.eval()
total_loss = 0
@ -74,19 +78,20 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
) -> None:
"""
Fit the classifier on the training set `(x, y)`.
:param x: Training data.
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels
of shape (nb_samples,).
of shape (nb_samples,).
:param x_validation: Validation data (optional).
:param y_validation: Target validation values (class labels) one-hot-encoded of shape
(nb_samples, nb_classes) or index labels of shape (nb_samples,) (optional).
(nb_samples, nb_classes) or index labels of shape (nb_samples,) (optional).
:param batch_size: Size of batches.
:param nb_epochs: Number of epochs to use for training.
:param save_checkpoints: Boolean, save checkpoints if True.
:param save_entire_model: Boolean, save entire model if True, else save state dict.
:param path: path for saving checkpoint.
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
supported for PyTorch and providing it takes no effect.
supported for PyTorch and providing it takes no effect.
"""
# Put the model in the training mode
self._model.train()
@ -153,7 +158,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
def save_checkpoint_state_dict(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
"""
Saves checkpoint as latest.tar or best.tar
Saves checkpoint as latest.tar or best.tar.
:param is_best: whether the model is the best achieved model
:param path: path for saving checkpoint
:param filename: checkpoint name
@ -176,7 +182,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
def save_checkpoint_model(self, is_best: bool, path=os.getcwd(), filename="latest.tar") -> None:
"""
Saves checkpoint as latest.tar or best.tar
Saves checkpoint as latest.tar or best.tar.
:param is_best: whether the model is the best achieved model
:param path: path for saving checkpoint
:param filename: checkpoint name
@ -194,7 +201,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
"""
Load model only based on the check point path
Load model only based on the check point path.
:param model_name: check point filename
:param path: checkpoint path (default current work dir)
:return: loaded model
@ -219,21 +227,24 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
def load_latest_state_dict_checkpoint(self):
"""
Load model state dict only based on the check point path (latest.tar)
Load model state dict only based on the check point path (latest.tar).
:return: loaded model
"""
self.load_checkpoint_state_dict_by_path("latest.tar")
def load_best_state_dict_checkpoint(self):
"""
Load model state dict only based on the check point path (model_best.tar)
Load model state dict only based on the check point path (model_best.tar).
:return: loaded model
"""
self.load_checkpoint_state_dict_by_path("model_best.tar")
def load_checkpoint_model_by_path(self, model_name: str, path: str = None):
"""
Load model only based on the check point path
Load model only based on the check point path.
:param model_name: check point filename
:param path: checkpoint path (default current work dir)
:return: loaded model
@ -254,14 +265,16 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
def load_latest_model_checkpoint(self):
"""
Load entire model only based on the check point path (latest.tar)
Load entire model only based on the check point path (latest.tar).
:return: loaded model
"""
self.load_checkpoint_model_by_path("latest.tar")
def load_best_model_checkpoint(self):
"""
Load entire model only based on the check point path (model_best.tar)
Load entire model only based on the check point path (model_best.tar).
:return: loaded model
"""
self.load_checkpoint_model_by_path("model_best.tar")
@ -288,11 +301,11 @@ class PyTorchClassifier(PyTorchModel):
Initialization specifically for the PyTorch-based implementation.
:param model: PyTorch model. The output of the model can be logits, probabilities or anything else. Logits
output should be preferred where possible to ensure attack efficiency.
output should be preferred where possible to ensure attack efficiency.
:param output_type: The type of output the model yields (vector/label only for classifiers,
value for regressors)
:param loss: The loss function for which to compute gradients for training. The target label must be raw
categorical, i.e. not converted to one-hot encoding.
categorical, i.e. not converted to one-hot encoding.
:param input_shape: The shape of one input instance.
:param optimizer: The optimizer used to train the classifier.
:param black_box_access: Boolean describing the type of deployment of the model (when in production).
@ -311,7 +324,7 @@ class PyTorchClassifier(PyTorchModel):
@property
def loss(self):
"""
The pytorch model's loss function
The pytorch model's loss function.
:return: The pytorch model's loss function
"""
@ -320,7 +333,7 @@ class PyTorchClassifier(PyTorchModel):
@property
def optimizer(self):
"""
The pytorch model's optimizer
The pytorch model's optimizer.
:return: The pytorch model's optimizer
"""
@ -350,7 +363,7 @@ class PyTorchClassifier(PyTorchModel):
:param save_entire_model: Boolean, save entire model if True, else save state dict.
:param path: path for saving checkpoint.
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
supported for PyTorch and providing it takes no effect.
supported for PyTorch and providing it takes no effect.
"""
if validation_data is None:
self._art_model.fit(
@ -390,6 +403,7 @@ class PyTorchClassifier(PyTorchModel):
def score(self, test_data: PytorchData, **kwargs):
"""
Score the model using test data.
:param test_data: Test data.
:type test_data: `PytorchData`
:return: the score as float (between 0 and 1)
@ -400,7 +414,8 @@ class PyTorchClassifier(PyTorchModel):
def load_checkpoint_state_dict_by_path(self, model_name: str, path: str = None):
"""
Load model only based on the check point path
Load model only based on the check point path.
:param model_name: check point filename
:param path: checkpoint path (default current work dir)
:return: loaded model
@ -409,21 +424,24 @@ class PyTorchClassifier(PyTorchModel):
def load_latest_state_dict_checkpoint(self):
"""
Load model state dict only based on the check point path (latest.tar)
Load model state dict only based on the check point path (latest.tar).
:return: loaded model
"""
self._art_model.load_latest_state_dict_checkpoint()
def load_best_state_dict_checkpoint(self):
"""
Load model state dict only based on the check point path (model_best.tar)
Load model state dict only based on the check point path (model_best.tar).
:return: loaded model
"""
self._art_model.load_best_state_dict_checkpoint()
def load_checkpoint_model_by_path(self, model_name: str, path: str = None):
"""
Load model only based on the check point path
Load model only based on the check point path.
:param model_name: check point filename
:param path: checkpoint path (default current work dir)
:return: loaded model
@ -432,14 +450,16 @@ class PyTorchClassifier(PyTorchModel):
def load_latest_model_checkpoint(self):
"""
Load entire model only based on the check point path (latest.tar)
Load entire model only based on the check point path (latest.tar).
:return: loaded model
"""
self._art_model.load_latest_model_checkpoint()
def load_best_model_checkpoint(self):
"""
Load entire model only based on the check point path (model_best.tar)
Load entire model only based on the check point path (model_best.tar).
:return: loaded model
"""
self._art_model.load_best_model_checkpoint()