diff --git a/apt/utils/models/pytorch_model.py b/apt/utils/models/pytorch_model.py index a57830b..3805a86 100644 --- a/apt/utils/models/pytorch_model.py +++ b/apt/utils/models/pytorch_model.py @@ -2,17 +2,13 @@ import logging import os import random import shutil -from typing import Optional, Tuple, Dict +from typing import Optional, Tuple import numpy as np from art.utils import check_and_transform_label_format, logger -from sklearn.model_selection import train_test_split - -from sklearn.preprocessing import OneHotEncoder - from apt.utils.datasets.datasets import PytorchData from apt.utils.models import Model, ModelOutputType -from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE, Data +from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier import torch @@ -279,9 +275,7 @@ class PyTorchClassifier(PyTorchModel): :param train_data: Training data. :type train_data: `Dataset` """ - encoder = OneHotEncoder(sparse=False) - y_encoded = encoder.fit_transform(train_data.get_labels().reshape(-1, 1)) - self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs) + self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), **kwargs) def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE: """