mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-24 20:36:21 +02:00
fix
This commit is contained in:
parent
f484135d84
commit
521c8ce041
1 changed files with 12 additions and 18 deletions
|
|
@ -12,25 +12,18 @@ from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
|
|||
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
|
||||
import torch
|
||||
|
||||
|
||||
class PyTorchModel(Model):
|
||||
"""
|
||||
Wrapper class for pytorch models.
|
||||
"""
|
||||
|
||||
def score(self, test_data: Dataset, **kwargs):
|
||||
"""
|
||||
Score the model using test data.
|
||||
|
||||
:param test_data: Test data.
|
||||
:type train_data: `Dataset`
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
||||
"""
|
||||
Wrapper class for pytorch classifier model.
|
||||
"""
|
||||
|
||||
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None:
|
||||
"""
|
||||
Fit the classifier on the training set `(x, y)`.
|
||||
|
|
@ -79,15 +72,7 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
# Form the loss function
|
||||
loss = self._loss(model_outputs[-1], o_batch) # lgtm [py/call-to-non-callable]
|
||||
|
||||
# Do training
|
||||
if self._use_amp: # pragma: no cover
|
||||
from apex import amp # pylint: disable=E0611
|
||||
|
||||
with amp.scale_loss(loss, self._optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
|
||||
else:
|
||||
loss.backward()
|
||||
loss.backward()
|
||||
|
||||
self._optimizer.step()
|
||||
|
||||
|
|
@ -142,3 +127,12 @@ class PyTorchClassifier(PyTorchModel):
|
|||
:return: Predictions from the model (class probabilities, if supported).
|
||||
"""
|
||||
return self._art_model.predict(x.get_samples(), **kwargs)
|
||||
|
||||
def score(self, test_data: Dataset, **kwargs):
|
||||
"""
|
||||
Score the model using test data.
|
||||
|
||||
:param test_data: Test data.
|
||||
:type train_data: `Dataset`
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue