mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-26 21:36:22 +02:00
update pytorch wrapper to use torch loaders
fix tests and dataset style
This commit is contained in:
parent
fdc6005fce
commit
c77e34e373
4 changed files with 178 additions and 113 deletions
|
|
@ -39,7 +39,7 @@ def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
|
|||
if type(arr) == Tensor:
|
||||
return arr.detach().cpu().numpy()
|
||||
|
||||
raise ValueError('Non supported type: ', type(arr).__name__)
|
||||
raise ValueError("Non supported type: ", type(arr).__name__)
|
||||
|
||||
|
||||
def array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
|
||||
|
|
@ -56,7 +56,7 @@ def array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
|
|||
if type(arr) == Tensor:
|
||||
return arr
|
||||
|
||||
raise ValueError('Non supported type: ', type(arr).__name__)
|
||||
raise ValueError("Non supported type: ", type(arr).__name__)
|
||||
|
||||
|
||||
class Dataset(metaclass=ABCMeta):
|
||||
|
|
@ -109,7 +109,7 @@ class StoredDataset(Dataset):
|
|||
os.makedirs(dest_path, exist_ok=True)
|
||||
logger.info("Downloading the dataset...")
|
||||
urllib.request.urlretrieve(url, file_path)
|
||||
logger.info('Dataset Downloaded')
|
||||
logger.info("Dataset Downloaded")
|
||||
|
||||
if unzip:
|
||||
StoredDataset.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
|
||||
|
|
@ -156,7 +156,7 @@ class StoredDataset(Dataset):
|
|||
logger.info("Shuffling data")
|
||||
np.random.shuffle(data)
|
||||
|
||||
debug_data = data[:int(len(data) * ratio)]
|
||||
debug_data = data[: int(len(data) * ratio)]
|
||||
logger.info(f"Saving {ratio} of the data to {dest_datafile}")
|
||||
np.savetxt(dest_datafile, debug_data, delimiter=delimiter, fmt=fmt)
|
||||
|
||||
|
|
@ -164,8 +164,13 @@ class StoredDataset(Dataset):
|
|||
class ArrayDataset(Dataset):
|
||||
"""Dataset that is based on x and y arrays (e.g., numpy/pandas/list...)"""
|
||||
|
||||
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None,
|
||||
features_names: Optional = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
x: INPUT_DATA_ARRAY_TYPE,
|
||||
y: Optional[INPUT_DATA_ARRAY_TYPE] = None,
|
||||
features_names: Optional = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
ArrayDataset constructor.
|
||||
:param x: collection of data samples
|
||||
|
|
@ -183,7 +188,7 @@ class ArrayDataset(Dataset):
|
|||
self.features_names = x.columns.to_list()
|
||||
|
||||
if y is not None and len(self._x) != len(self._y):
|
||||
raise ValueError('Non equivalent lengths of x and y')
|
||||
raise ValueError("Non equivalent lengths of x and y")
|
||||
|
||||
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""Return data samples as numpy array"""
|
||||
|
|
@ -195,7 +200,6 @@ class ArrayDataset(Dataset):
|
|||
|
||||
|
||||
class PytorchData(Dataset):
|
||||
|
||||
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
|
||||
"""
|
||||
PytorchData constructor.
|
||||
|
|
@ -210,15 +214,13 @@ class PytorchData(Dataset):
|
|||
self.features_names = x.columns
|
||||
|
||||
if y is not None and len(self._x) != len(self._y):
|
||||
raise ValueError('Non equivalent lengths of x and y')
|
||||
|
||||
raise ValueError("Non equivalent lengths of x and y")
|
||||
|
||||
if self._y is not None:
|
||||
self.__getitem__ = self.get_item
|
||||
else:
|
||||
self.__getitem__ = self.get_sample_item
|
||||
|
||||
|
||||
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""Return data samples as numpy array"""
|
||||
return array2numpy(self, self._x)
|
||||
|
|
@ -240,6 +242,7 @@ class PytorchData(Dataset):
|
|||
|
||||
class DatasetFactory:
|
||||
"""Factory class for dataset creation"""
|
||||
|
||||
registry = {}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -252,7 +255,7 @@ class DatasetFactory:
|
|||
|
||||
def inner_wrapper(wrapped_class: Dataset) -> Any:
|
||||
if name in cls.registry:
|
||||
logger.warning('Dataset %s already exists. Will replace it', name)
|
||||
logger.warning("Dataset %s already exists. Will replace it", name)
|
||||
cls.registry[name] = wrapped_class
|
||||
return wrapped_class
|
||||
|
||||
|
|
@ -270,7 +273,7 @@ class DatasetFactory:
|
|||
:return: An instance of the dataset that is created.
|
||||
"""
|
||||
if name not in cls.registry:
|
||||
msg = f'Dataset {name} does not exist in the registry'
|
||||
msg = f"Dataset {name} does not exist in the registry"
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue