diff --git a/apt/utils/datasets/datasets.py b/apt/utils/datasets/datasets.py index 5521bba..2e0a70a 100644 --- a/apt/utils/datasets/datasets.py +++ b/apt/utils/datasets/datasets.py @@ -273,9 +273,9 @@ class DatasetWithPredictions(Dataset): y: Optional[INPUT_DATA_ARRAY_TYPE] = None, features_names: Optional[list] = None, **kwargs): self.is_pandas = False self.features_names = features_names - self._pred = self._array2numpy(pred) - self._y = self._array2numpy(y) if y is not None else None - self._x = self._array2numpy(x) if x is not None else None + self._pred = array2numpy(pred) + self._y = array2numpy(y) if y is not None else None + self._x = array2numpy(x) if x is not None else None if self.is_pandas and x is not None: if features_names and not np.array_equal(features_names, x.columns): raise ValueError("The supplied features are not the same as in the data features") diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index c1f36c4..44e9013 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -54,7 +54,7 @@ def test_nursery_pytorch_state_dict(): criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(inner_model.parameters(), lr=0.01) - model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, + model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion, optimizer=optimizer, input_shape=(24,), nb_classes=4) model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10) @@ -84,7 +84,7 @@ def test_nursery_pytorch_save_entire_model(): criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.01) - art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion, + art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion, optimizer=optimizer, input_shape=(24,), nb_classes=4) art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10)