diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index 963cf31..9659b7c 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -12,25 +12,18 @@ from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier import torch + class PyTorchModel(Model): """ Wrapper class for pytorch models. """ - def score(self, test_data: Dataset, **kwargs): - """ - Score the model using test data. - - :param test_data: Test data. - :type train_data: `Dataset` - """ - pass - class PyTorchClassifierWrapper(ArtPyTorchClassifier): """ Wrapper class for pytorch classifier model. """ + def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None: """ Fit the classifier on the training set `(x, y)`. @@ -79,15 +72,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier): # Form the loss function loss = self._loss(model_outputs[-1], o_batch) # lgtm [py/call-to-non-callable] - # Do training - if self._use_amp: # pragma: no cover - from apex import amp # pylint: disable=E0611 - - with amp.scale_loss(loss, self._optimizer) as scaled_loss: - scaled_loss.backward() - - else: - loss.backward() + loss.backward() self._optimizer.step() @@ -142,3 +127,12 @@ class PyTorchClassifier(PyTorchModel): :return: Predictions from the model (class probabilities, if supported). """ return self._art_model.predict(x.get_samples(), **kwargs) + + def score(self, test_data: Dataset, **kwargs): + """ + Score the model using test data. + + :param test_data: Test data. + :type train_data: `Dataset` + """ + pass