mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-04 17:22:37 +02:00
fix
This commit is contained in:
parent
07e64b1f86
commit
4973fbebc6
2 changed files with 34 additions and 23 deletions
|
|
@ -23,6 +23,7 @@ class PyTorchModel(Model):
|
|||
class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
||||
"""
|
||||
Wrapper class for pytorch classifier model.
|
||||
Extension for Pytorch ART model
|
||||
"""
|
||||
|
||||
def get_step_correct(self, outputs, targets) -> int:
|
||||
|
|
@ -187,6 +188,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
|
||||
if self._optimizer and 'opt_state_dict' in checkpoint:
|
||||
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
|
||||
self.model.eval()
|
||||
|
||||
def load_latest_state_dict_checkpoint(self):
|
||||
"""
|
||||
|
|
@ -266,14 +268,23 @@ class PyTorchClassifier(PyTorchModel):
|
|||
super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs)
|
||||
self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)
|
||||
|
||||
def fit(self, train_data: PytorchData, **kwargs) -> None:
|
||||
def fit(self, train_data: PytorchData, batch_size: int = 128, nb_epochs: int = 10,
|
||||
save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None:
|
||||
"""
|
||||
Fit the model using the training data.
|
||||
|
||||
:param train_data: Training data.
|
||||
:type train_data: `Dataset`
|
||||
:param batch_size: Size of batches.
|
||||
:param nb_epochs: Number of epochs to use for training.
|
||||
:param save_checkpoints: Boolean, save checkpoints if True.
|
||||
:param save_entire_model: Boolean, save entire model if True, else save state dict.
|
||||
:param path: path for saving checkpoint.
|
||||
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
|
||||
supported for PyTorch and providing it takes no effect.
|
||||
"""
|
||||
self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), **kwargs)
|
||||
self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), batch_size, nb_epochs,
|
||||
save_checkpoints, save_entire_model, path, **kwargs)
|
||||
|
||||
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue