This commit is contained in:
abigailt 2022-03-07 19:09:31 +02:00
parent 3d82db80c4
commit f2df2fcc8c
6 changed files with 35 additions and 43 deletions

View file

@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame]
class DatasetABC(metaclass=ABCMeta):
class Dataset(metaclass=ABCMeta):
"""Base Abstract Class for Dataset"""
@abstractmethod
@@ -38,7 +38,7 @@ class DatasetABC(metaclass=ABCMeta):
pass
class StoredDatasetABC(DatasetABC):
class StoredDataset(Dataset):
"""Abstract Class for Storable Dataset"""
@abstractmethod
@@ -73,7 +73,7 @@ class StoredDatasetABC(DatasetABC):
logger.info('Dataset Downloaded')
if unzip:
StoredDatasetABC.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
StoredDataset.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
@staticmethod
@@ -123,12 +123,12 @@ class StoredDatasetABC(DatasetABC):
np.savetxt(dest_datafile, debug_data, delimiter=delimiter, fmt=fmt)
class BaseDataset(DatasetABC):
"""Base Class for Dataset"""
class ArrayDataset(Dataset):
"""Dataset that is based on x and y arrays (e.g., numpy/pandas)"""
def __init__(self, x: DATA_ARRAY_TYPE, y: DATA_ARRAY_TYPE, **kwargs):
"""
BaseDataset constructor.
ArrayDataset constructor.
:param x: collection of data samples
:param y: collection of labels
:param kwargs: dataset parameters
@@ -159,7 +159,7 @@ class DatasetFactory:
:param name: dataset name
:return:
"""
def inner_wrapper(wrapped_class: DatasetABC) -> Any:
def inner_wrapper(wrapped_class: Dataset) -> Any:
if name in cls.registry:
logger.warning('Dataset %s already exists. Will replace it', name)
cls.registry[name] = wrapped_class
@@ -168,7 +168,7 @@ class DatasetFactory:
return inner_wrapper
@classmethod
def create_dataset(cls, name: str, **kwargs) -> DatasetABC:
def create_dataset(cls, name: str, **kwargs) -> Dataset:
"""
Factory command to create dataset instance.
This method gets the appropriate Dataset class from the registry
@@ -190,7 +190,7 @@ class DatasetFactory:
class Data:
def __init__(self, train: DatasetABC = None, test: DatasetABC = None, **kwargs):
def __init__(self, train: Dataset = None, test: Dataset = None, **kwargs):
"""
Data class constructor.
The class stores train and test datasets.
@@ -205,11 +205,11 @@ class Data:
self.train = DatasetFactory.create_dataset(train=True, **kwargs)
self.test = DatasetFactory.create_dataset(train=False, **kwargs)
def get_train_set(self) -> DatasetABC:
def get_train_set(self) -> Dataset:
"""Return train DatasetBase"""
return self.train
def get_test_set(self) -> DatasetABC:
def get_test_set(self) -> Dataset:
"""Return test DatasetBase"""
return self.test