From 44d012857f6e0bb6c5617e5526673b10213d9e26 Mon Sep 17 00:00:00 2001 From: abigailt Date: Wed, 19 Oct 2022 17:10:42 +0300 Subject: [PATCH] Add loss and optimizer as properties Signed-off-by: abigailt --- apt/utils/models/pytorch_model.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index f3eb86c..3eb130c 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -304,8 +304,28 @@ class PyTorchClassifier(PyTorchModel): queries that can be submitted. Optional, Default is True. """ super().__init__(model, output_type, black_box_access, unlimited_queries, **kwargs) + self._loss = loss + self._optimizer = optimizer self._art_model = PyTorchClassifierWrapper(model, loss, input_shape, nb_classes, optimizer) + @property + def loss(self): + """ + The pytorch model's loss function + + :return: The pytorch model's loss function + """ + return self._loss + + @property + def optimizer(self): + """ + The pytorch model's optimizer + + :return: The pytorch model's optimizer + """ + return self._optimizer + def fit( self, train_data: PytorchData,