def array2numpy(arr: "INPUT_DATA_ARRAY_TYPE") -> "OUTPUT_DATA_ARRAY_TYPE":
    """
    Convert a supported array-like object to a numpy array.

    :param arr: a numpy array, pandas DataFrame, list or torch Tensor
    :return: ``arr`` itself when it is already a numpy array, otherwise its numpy
             equivalent. Tensors are detached and moved to the CPU before conversion.
    :raises ValueError: if ``arr`` is not one of the supported types
    """
    # isinstance (not exact type equality) so that subclasses such as
    # torch.nn.Parameter or DataFrame subclasses are accepted as well.
    if isinstance(arr, np.ndarray):
        return arr
    if isinstance(arr, pd.DataFrame):
        return arr.to_numpy()
    if isinstance(arr, list):
        return np.array(arr)
    if isinstance(arr, Tensor):
        return arr.detach().cpu().numpy()

    # Build a single message string; passing two positional args to ValueError
    # would render the message as a tuple.
    raise ValueError('Non supported type: ' + type(arr).__name__)


def array2torch_tensor(arr: "INPUT_DATA_ARRAY_TYPE") -> Tensor:
    """
    Convert a supported array-like object to a torch tensor.

    :param arr: a numpy array, pandas DataFrame, list or torch Tensor
    :return: ``arr`` itself when it is already a tensor, otherwise a tensor built from it.
             numpy input (including the array extracted from a DataFrame) goes through
             ``torch.from_numpy``, so the tensor shares memory with that array.
    :raises ValueError: if ``arr`` is not one of the supported types
    """
    if isinstance(arr, Tensor):
        return arr
    if isinstance(arr, np.ndarray):
        return torch.from_numpy(arr)
    if isinstance(arr, pd.DataFrame):
        return torch.from_numpy(arr.to_numpy())
    if isinstance(arr, list):
        return torch.tensor(arr)

    raise ValueError('Non supported type: ' + type(arr).__name__)
class PytorchData(Dataset):
    """
    Dataset that stores its samples (and optional labels) as torch tensors.

    Supports ``len(dataset)`` and ``dataset[idx]``: indexing yields a
    ``(sample, label)`` pair when labels were supplied, and the sample alone otherwise.
    """

    def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
        """
        PytorchData constructor.
        :param x: collection of data samples
        :param y: collection of labels (optional)
        :param kwargs: dataset parameters
        :raises ValueError: if x and y have different lengths, or are of an unsupported type
        """
        self._x = array2torch_tensor(x)
        self._y = array2torch_tensor(y) if y is not None else None

        if self._y is not None and len(self._x) != len(self._y):
            raise ValueError('Non equivalent lengths of x and y')

    def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
        """Return data samples as numpy array"""
        return array2numpy(self._x)

    def get_labels(self) -> Optional[OUTPUT_DATA_ARRAY_TYPE]:
        """Return labels as numpy array, or None when the dataset has no labels"""
        return array2numpy(self._y) if self._y is not None else None

    def get_sample_item(self, idx) -> Tensor:
        """Return the sample stored at ``idx`` (without its label)"""
        return self._x[idx]

    def get_item(self, idx) -> tuple:
        """Return the ``(sample, label)`` pair stored at ``idx``; requires labels"""
        return self._x[idx], self._y[idx]

    def __getitem__(self, idx):
        """
        Return ``(sample, label)`` when labels exist, otherwise just the sample.

        Defined on the class because Python looks special methods up on the type;
        a ``__getitem__`` attribute set on the instance is ignored by ``dataset[idx]``.
        """
        if self._y is None:
            return self.get_sample_item(idx)
        return self.get_item(idx)

    def __len__(self):
        """Return the number of samples"""
        return len(self._x)