This commit is contained in:
olasaadi 2022-07-19 21:16:39 +03:00
parent 07e64b1f86
commit 4973fbebc6
2 changed files with 34 additions and 23 deletions

View file

@ -23,6 +23,7 @@ class PyTorchModel(Model):
class PyTorchClassifierWrapper(ArtPyTorchClassifier):
"""
Wrapper class for pytorch classifier model.
Extension for Pytorch ART model
"""
def get_step_correct(self, outputs, targets) -> int:
@ -187,6 +188,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
if self._optimizer and 'opt_state_dict' in checkpoint:
self._optimizer.load_state_dict(checkpoint['opt_state_dict'])
self.model.eval()
def load_latest_state_dict_checkpoint(self):
"""
@ -266,14 +268,23 @@ class PyTorchClassifier(PyTorchModel):
super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs)
self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer)
def fit(self, train_data: PytorchData, **kwargs) -> None:
def fit(self, train_data: PytorchData, batch_size: int = 128, nb_epochs: int = 10,
save_checkpoints: bool = True, save_entire_model=True, path=os.getcwd(), **kwargs) -> None:
"""
Fit the model using the training data.
:param train_data: Training data.
:type train_data: `Dataset`
:param batch_size: Size of batches.
:param nb_epochs: Number of epochs to use for training.
:param save_checkpoints: Boolean, save checkpoints if True.
:param save_entire_model: Boolean, save entire model if True, else save state dict.
:param path: path for saving checkpoint.
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently
supported for PyTorch and providing it takes no effect.
"""
self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), **kwargs)
self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), batch_size, nb_epochs,
save_checkpoints, save_entire_model, path, **kwargs)
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
"""