mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-24 20:36:21 +02:00
add pytorch Dataset
This commit is contained in:
parent
45cc9180b8
commit
f99bf31030
1 changed files with 75 additions and 24 deletions
|
|
@ -13,6 +13,7 @@ import urllib.request
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import logging
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -23,6 +24,38 @@ OUTPUT_DATA_ARRAY_TYPE = np.ndarray
|
|||
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
|
||||
|
||||
|
||||
def array2numpy(arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""
|
||||
converts from INPUT_DATA_ARRAY_TYPE to numpy array
|
||||
"""
|
||||
if type(arr) == np.ndarray:
|
||||
return arr
|
||||
if type(arr) == pd.DataFrame:
|
||||
return arr.to_numpy()
|
||||
if isinstance(arr, list):
|
||||
return np.array(arr)
|
||||
if type(arr) == Tensor:
|
||||
return arr.detach().cpu().numpy()
|
||||
|
||||
raise ValueError('Non supported type: ', type(arr).__name__)
|
||||
|
||||
|
||||
def array2torch_tensor(arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
|
||||
"""
|
||||
converts from INPUT_DATA_ARRAY_TYPE to torch tensor array
|
||||
"""
|
||||
if type(arr) == np.ndarray:
|
||||
return torch.from_numpy(arr)
|
||||
if type(arr) == pd.DataFrame:
|
||||
return torch.from_numpy(arr.to_numpy())
|
||||
if isinstance(arr, list):
|
||||
return torch.tensor(arr)
|
||||
if type(arr) == Tensor:
|
||||
return arr
|
||||
|
||||
raise ValueError('Non supported type: ', type(arr).__name__)
|
||||
|
||||
|
||||
class Dataset(metaclass=ABCMeta):
|
||||
"""Base Abstract Class for Dataset"""
|
||||
|
||||
|
|
@ -136,30 +169,8 @@ class ArrayDataset(Dataset):
|
|||
:param y: collection of labels (optional)
|
||||
:param kwargs: dataset parameters
|
||||
"""
|
||||
# convert to numpy
|
||||
if type(x) == np.ndarray:
|
||||
self._x = x
|
||||
elif type(x) == pd.DataFrame:
|
||||
self._x = x.to_numpy()
|
||||
elif isinstance(x, list):
|
||||
self._x = np.array(x)
|
||||
elif type(x) == Tensor:
|
||||
self._x = x.numpy()
|
||||
else:
|
||||
raise ValueError('Non supported type for x: ', type(x).__name__)
|
||||
|
||||
self._y = None
|
||||
if y is not None:
|
||||
if type(y) == np.ndarray:
|
||||
self._y = y
|
||||
elif type(y) == pd.DataFrame:
|
||||
self._y = y.to_numpy()
|
||||
elif isinstance(y, list):
|
||||
self._y = np.array(y)
|
||||
elif type(y) == Tensor:
|
||||
self._y = y.numpy()
|
||||
else:
|
||||
raise ValueError('Non supported type for y: ', type(y).__name__)
|
||||
self._x = array2numpy(x)
|
||||
self._y = array2numpy(y) if y is not None else None
|
||||
|
||||
if y is not None and len(self._x) != len(self._y):
|
||||
raise ValueError('Non equivalent lengths of x and y')
|
||||
|
|
@ -173,6 +184,46 @@ class ArrayDataset(Dataset):
|
|||
return self._y
|
||||
|
||||
|
||||
class PytorchData(Dataset):
|
||||
|
||||
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
|
||||
"""
|
||||
PytorchData constructor.
|
||||
:param x: collection of data samples
|
||||
:param y: collection of labels (optional)
|
||||
:param kwargs: dataset parameters
|
||||
"""
|
||||
self._x = array2torch_tensor(x)
|
||||
self._y = array2torch_tensor(y) if y is not None else None
|
||||
|
||||
if y is not None and len(self._x) != len(self._y):
|
||||
raise ValueError('Non equivalent lengths of x and y')
|
||||
|
||||
if self._y is not None:
|
||||
self.__getitem__ = self.get_item
|
||||
else:
|
||||
self.__getitem__ = self.get_sample_item
|
||||
|
||||
|
||||
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""Return data samples as numpy array"""
|
||||
return array2numpy(self._x)
|
||||
|
||||
def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""Return labels as numpy array"""
|
||||
return array2numpy(self._y) if self._y is not None else None
|
||||
|
||||
def get_sample_item(self, idx) -> Tensor:
|
||||
return self.x[idx]
|
||||
|
||||
def get_item(self, idx) -> Tensor:
|
||||
sample, label = self.x[idx], self.y[idx]
|
||||
return sample, label
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
|
||||
|
||||
class DatasetFactory:
|
||||
"""Factory class for dataset creation"""
|
||||
registry = {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue