This commit is contained in:
olasaadi 2022-05-17 16:57:21 +03:00
parent f484135d84
commit 521c8ce041

View file

@ -12,25 +12,18 @@ from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
import torch
class PyTorchModel(Model):
"""
Wrapper class for pytorch models.
"""
def score(self, test_data: Dataset, **kwargs):
"""
Score the model using test data.
:param test_data: Test data.
:type train_data: `Dataset`
"""
pass
class PyTorchClassifierWrapper(ArtPyTorchClassifier):
"""
Wrapper class for pytorch classifier model.
"""
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None:
"""
Fit the classifier on the training set `(x, y)`.
@ -79,15 +72,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
# Form the loss function
loss = self._loss(model_outputs[-1], o_batch) # lgtm [py/call-to-non-callable]
# Do training
if self._use_amp: # pragma: no cover
from apex import amp # pylint: disable=E0611
with amp.scale_loss(loss, self._optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
loss.backward()
self._optimizer.step()
@ -142,3 +127,12 @@ class PyTorchClassifier(PyTorchModel):
:return: Predictions from the model (class probabilities, if supported).
"""
return self._art_model.predict(x.get_samples(), **kwargs)
def score(self, test_data: Dataset, **kwargs):
"""
Score the model using test data.
:param test_data: Test data.
:type train_data: `Dataset`
"""
pass