add pytorch Dataset

This commit is contained in:
Ron Shmelkin 2022-03-15 15:33:14 +02:00
parent 45cc9180b8
commit f99bf31030

View file

@ -13,6 +13,7 @@ import urllib.request
import numpy as np
import pandas as pd
import logging
import torch
from torch import Tensor
logger = logging.getLogger(__name__)
@ -23,6 +24,38 @@ OUTPUT_DATA_ARRAY_TYPE = np.ndarray
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
def array2numpy(arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
    """
    Convert a supported array-like input to a numpy array.

    Type checks use ``isinstance`` so that subclasses of the supported types
    (for example ``torch.nn.Parameter``, which is a ``Tensor`` subclass) are
    converted as well instead of being rejected.

    :param arr: numpy array, pandas DataFrame, list or torch Tensor
    :return: the data as a numpy array; a numpy input is returned as-is (no copy is made)
    :raises ValueError: if the type of ``arr`` is not one of the supported types
    """
    if isinstance(arr, np.ndarray):
        return arr
    if isinstance(arr, pd.DataFrame):
        return arr.to_numpy()
    if isinstance(arr, list):
        return np.array(arr)
    if isinstance(arr, Tensor):
        # drop autograd tracking and move to host memory before handing over to numpy
        return arr.detach().cpu().numpy()

    # single formatted message (passing two positional args renders as a tuple repr)
    raise ValueError(f'Non supported type: {type(arr).__name__}')
def array2torch_tensor(arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
    """
    Convert a supported array-like input to a torch tensor.

    Type checks use ``isinstance`` so that subclasses of the supported types
    (for example ``torch.nn.Parameter``) are accepted as well.

    :param arr: numpy array, pandas DataFrame, list or torch Tensor
    :return: the data as a torch tensor; a Tensor input is returned as-is. For numpy input the
             result is created with ``torch.from_numpy``, which shares memory with the source array.
    :raises ValueError: if the type of ``arr`` is not one of the supported types
    """
    if isinstance(arr, np.ndarray):
        return torch.from_numpy(arr)
    if isinstance(arr, pd.DataFrame):
        return torch.from_numpy(arr.to_numpy())
    if isinstance(arr, list):
        return torch.tensor(arr)
    if isinstance(arr, Tensor):
        return arr

    # single formatted message (passing two positional args renders as a tuple repr)
    raise ValueError(f'Non supported type: {type(arr).__name__}')
class Dataset(metaclass=ABCMeta):
"""Base Abstract Class for Dataset"""
@ -136,30 +169,8 @@ class ArrayDataset(Dataset):
:param y: collection of labels (optional)
:param kwargs: dataset parameters
"""
# convert to numpy
if type(x) == np.ndarray:
self._x = x
elif type(x) == pd.DataFrame:
self._x = x.to_numpy()
elif isinstance(x, list):
self._x = np.array(x)
elif type(x) == Tensor:
self._x = x.numpy()
else:
raise ValueError('Non supported type for x: ', type(x).__name__)
self._y = None
if y is not None:
if type(y) == np.ndarray:
self._y = y
elif type(y) == pd.DataFrame:
self._y = y.to_numpy()
elif isinstance(y, list):
self._y = np.array(y)
elif type(y) == Tensor:
self._y = y.numpy()
else:
raise ValueError('Non supported type for y: ', type(y).__name__)
self._x = array2numpy(x)
self._y = array2numpy(y) if y is not None else None
if y is not None and len(self._x) != len(self._y):
raise ValueError('Non equivalent lengths of x and y')
@ -173,6 +184,46 @@ class ArrayDataset(Dataset):
return self._y
class PytorchData(Dataset):
    """Dataset that keeps samples (and optional labels) as torch tensors and supports indexing."""

    def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
        """
        PytorchData constructor.

        :param x: collection of data samples
        :param y: collection of labels (optional)
        :param kwargs: dataset parameters
        :raises ValueError: if ``y`` is given and its length differs from the length of ``x``
        """
        self._x = array2torch_tensor(x)
        self._y = array2torch_tensor(y) if y is not None else None

        if self._y is not None and len(self._x) != len(self._y):
            raise ValueError('Non equivalent lengths of x and y')

    def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
        """Return data samples as numpy array"""
        return array2numpy(self._x)

    def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE:
        """Return labels as numpy array, or None when the dataset has no labels"""
        return array2numpy(self._y) if self._y is not None else None

    def get_sample_item(self, idx) -> Tensor:
        """
        Return the sample stored at ``idx``.

        :param idx: index passed straight to tensor indexing
        :return: the selected sample(s) as a tensor
        """
        return self._x[idx]

    def get_item(self, idx) -> tuple:
        """
        Return the ``(sample, label)`` pair stored at ``idx``. Requires labels to be present.

        :param idx: index passed straight to tensor indexing
        :return: tuple of (sample, label) tensors
        """
        return self._x[idx], self._y[idx]

    def __getitem__(self, idx):
        """
        Index the dataset: ``(sample, label)`` when labels exist, otherwise only the sample.

        Defined on the class rather than assigned per instance, because ``dataset[idx]``
        looks special methods up on the type and ignores instance attributes.
        """
        if self._y is None:
            return self.get_sample_item(idx)
        return self.get_item(idx)

    def __len__(self):
        """Return the number of samples"""
        return len(self._x)
class DatasetFactory:
"""Factory class for dataset creation"""
registry = {}