mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-28 22:36:22 +02:00
fix notebook and add features_names to ArrayDataset
and allow providing features names in QI and Cat features not just indexes
This commit is contained in:
parent
137167fb0c
commit
66c86dc595
5 changed files with 89 additions and 74 deletions
|
|
@ -18,7 +18,6 @@ from torch import Tensor
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor]
|
||||
OUTPUT_DATA_ARRAY_TYPE = np.ndarray
|
||||
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
|
||||
|
|
@ -113,7 +112,6 @@ class StoredDataset(Dataset):
|
|||
if unzip:
|
||||
StoredDataset.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_archive(zip_path: str, dest_path=None, remove_archive=False):
|
||||
"""
|
||||
|
|
@ -164,7 +162,8 @@ class StoredDataset(Dataset):
|
|||
class ArrayDataset(Dataset):
|
||||
"""Dataset that is based on x and y arrays (e.g., numpy/pandas/list...)"""
|
||||
|
||||
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
|
||||
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, features_names=None,
|
||||
**kwargs):
|
||||
"""
|
||||
ArrayDataset constructor.
|
||||
:param x: collection of data samples
|
||||
|
|
@ -172,10 +171,12 @@ class ArrayDataset(Dataset):
|
|||
:param kwargs: dataset parameters
|
||||
"""
|
||||
self.is_pandas = False
|
||||
self.features_names = None
|
||||
self.features_names = features_names
|
||||
self._y = array2numpy(self, y) if y is not None else None
|
||||
self._x = array2numpy(self, x)
|
||||
if self.is_pandas:
|
||||
if features_names and not np.array_equal(features_names, x.columns):
|
||||
raise ValueError("The supplied features are not the same as in the data features")
|
||||
self.features_names = x.columns
|
||||
|
||||
if y is not None and len(self._x) != len(self._y):
|
||||
|
|
@ -213,7 +214,6 @@ class PytorchData(Dataset):
|
|||
else:
|
||||
self.__getitem__ = self.get_sample_item
|
||||
|
||||
|
||||
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""Return data samples as numpy array"""
|
||||
return array2numpy(self._x)
|
||||
|
|
@ -244,6 +244,7 @@ class DatasetFactory:
|
|||
:param name: dataset name
|
||||
:return:
|
||||
"""
|
||||
|
||||
def inner_wrapper(wrapped_class: Dataset) -> Any:
|
||||
if name in cls.registry:
|
||||
logger.warning('Dataset %s already exists. Will replace it', name)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue