fix notebook and add features_names to ArrayDataset

Also allow QI and Cat features to be given as feature names, not just as indexes.
This commit is contained in:
olasaadi 2022-03-24 19:32:24 +02:00
parent 137167fb0c
commit 66c86dc595
5 changed files with 89 additions and 74 deletions

View file

@ -18,7 +18,6 @@ from torch import Tensor
logger = logging.getLogger(__name__)
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor]
OUTPUT_DATA_ARRAY_TYPE = np.ndarray
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
@ -113,7 +112,6 @@ class StoredDataset(Dataset):
if unzip:
StoredDataset.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
@staticmethod
def extract_archive(zip_path: str, dest_path=None, remove_archive=False):
"""
@ -164,7 +162,8 @@ class StoredDataset(Dataset):
class ArrayDataset(Dataset):
"""Dataset that is based on x and y arrays (e.g., numpy/pandas/list...)"""
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, features_names=None,
**kwargs):
"""
ArrayDataset constructor.
:param x: collection of data samples
@ -172,10 +171,12 @@ class ArrayDataset(Dataset):
:param kwargs: dataset parameters
"""
self.is_pandas = False
self.features_names = None
self.features_names = features_names
self._y = array2numpy(self, y) if y is not None else None
self._x = array2numpy(self, x)
if self.is_pandas:
if features_names and not np.array_equal(features_names, x.columns):
raise ValueError("The supplied features are not the same as in the data features")
self.features_names = x.columns
if y is not None and len(self._x) != len(self._y):
@ -213,7 +214,6 @@ class PytorchData(Dataset):
else:
self.__getitem__ = self.get_sample_item
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""Return data samples as numpy array"""
return array2numpy(self._x)
@ -244,6 +244,7 @@ class DatasetFactory:
:param name: dataset name
:return:
"""
def inner_wrapper(wrapped_class: Dataset) -> Any:
if name in cls.registry:
logger.warning('Dataset %s already exists. Will replace it', name)