This commit is contained in:
olasaadi 2022-07-04 12:55:58 +03:00
parent 21cba95a28
commit af7d615628

View file

@ -2,17 +2,13 @@ import logging
import os
import random
import shutil
from typing import Optional, Tuple, Dict
from typing import Optional, Tuple
import numpy as np
from art.utils import check_and_transform_label_format, logger
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import Model, ModelOutputType
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE, Data
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.pytorch import PyTorchClassifier as ArtPyTorchClassifier
import torch
@ -279,9 +275,7 @@ class PyTorchClassifier(PyTorchModel):
:param train_data: Training data.
:type train_data: `Dataset`
"""
encoder = OneHotEncoder(sparse=False)
y_encoded = encoder.fit_transform(train_data.get_labels().reshape(-1, 1))
self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs)
self._art_model.fit(train_data.get_samples(), train_data.get_labels().reshape(-1, 1), **kwargs)
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
"""