mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-26 21:36:22 +02:00
Renaming
This commit is contained in:
parent
3d82db80c4
commit
f2df2fcc8c
6 changed files with 35 additions and 43 deletions
|
|
@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
|
|||
DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame]
|
||||
|
||||
|
||||
class DatasetABC(metaclass=ABCMeta):
|
||||
class Dataset(metaclass=ABCMeta):
|
||||
"""Base Abstract Class for Dataset"""
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -38,7 +38,7 @@ class DatasetABC(metaclass=ABCMeta):
|
|||
pass
|
||||
|
||||
|
||||
class StoredDatasetABC(DatasetABC):
|
||||
class StoredDataset(Dataset):
|
||||
"""Abstract Class for Storable Dataset"""
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -73,7 +73,7 @@ class StoredDatasetABC(DatasetABC):
|
|||
logger.info('Dataset Downloaded')
|
||||
|
||||
if unzip:
|
||||
StoredDatasetABC.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
|
||||
StoredDataset.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -123,12 +123,12 @@ class StoredDatasetABC(DatasetABC):
|
|||
np.savetxt(dest_datafile, debug_data, delimiter=delimiter, fmt=fmt)
|
||||
|
||||
|
||||
class BaseDataset(DatasetABC):
|
||||
"""Base Class for Dataset"""
|
||||
class ArrayDataset(Dataset):
|
||||
"""Dataset that is based on x and y arrays (e.g., numpy/pandas)"""
|
||||
|
||||
def __init__(self, x: DATA_ARRAY_TYPE, y: DATA_ARRAY_TYPE, **kwargs):
|
||||
"""
|
||||
BaseDataset constructor.
|
||||
ArrayDataset constructor.
|
||||
:param x: collection of data samples
|
||||
:param y: collection of labels
|
||||
:param kwargs: dataset parameters
|
||||
|
|
@ -159,7 +159,7 @@ class DatasetFactory:
|
|||
:param name: dataset name
|
||||
:return:
|
||||
"""
|
||||
def inner_wrapper(wrapped_class: DatasetABC) -> Any:
|
||||
def inner_wrapper(wrapped_class: Dataset) -> Any:
|
||||
if name in cls.registry:
|
||||
logger.warning('Dataset %s already exists. Will replace it', name)
|
||||
cls.registry[name] = wrapped_class
|
||||
|
|
@ -168,7 +168,7 @@ class DatasetFactory:
|
|||
return inner_wrapper
|
||||
|
||||
@classmethod
|
||||
def create_dataset(cls, name: str, **kwargs) -> DatasetABC:
|
||||
def create_dataset(cls, name: str, **kwargs) -> Dataset:
|
||||
"""
|
||||
Factory command to create dataset instance.
|
||||
This method gets the appropriate Dataset class from the registry
|
||||
|
|
@ -190,7 +190,7 @@ class DatasetFactory:
|
|||
|
||||
class Data:
|
||||
|
||||
def __init__(self, train: DatasetABC = None, test: DatasetABC = None, **kwargs):
|
||||
def __init__(self, train: Dataset = None, test: Dataset = None, **kwargs):
|
||||
"""
|
||||
Data class constructor.
|
||||
The class stores train and test datasets.
|
||||
|
|
@ -205,11 +205,11 @@ class Data:
|
|||
self.train = DatasetFactory.create_dataset(train=True, **kwargs)
|
||||
self.test = DatasetFactory.create_dataset(train=False, **kwargs)
|
||||
|
||||
def get_train_set(self) -> DatasetABC:
|
||||
def get_train_set(self) -> Dataset:
|
||||
"""Return train DatasetBase"""
|
||||
return self.train
|
||||
|
||||
def get_test_set(self) -> DatasetABC:
|
||||
def get_test_set(self) -> Dataset:
|
||||
"""Return test DatasetBase"""
|
||||
return self.test
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue