diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index a7f15e2..729150f 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -302,7 +302,7 @@ class PyTorchClassifier(PyTorchModel): """ Score the model using test data. :param test_data: Test data. - :type train_data: `Dataset` + :type test_data: `Dataset` :return: the score as float (between 0 and 1) """ y = check_and_transform_label_format(test_data.get_labels(), self._art_model.nb_classes)