Add more to wrappers

This commit is contained in:
abigailt 2022-03-15 11:42:57 +02:00 committed by olasaadi
parent a2b560920f
commit a9162fbc43
6 changed files with 86 additions and 38 deletions

View file

@ -5,7 +5,7 @@ Implementation of utility classes for dataset handling
"""
from abc import ABCMeta, abstractmethod
from typing import Callable, Collection, Any, Union
from typing import Callable, Collection, Any, Union, List, Optional
import tarfile
import os
@ -13,11 +13,14 @@ import urllib.request
import numpy as np
import pandas as pd
import logging
from torch import Tensor
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Array-like containers: numpy arrays or pandas DataFrames.
DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame]
# Containers accepted as input data: numpy arrays, pandas DataFrames, lists and torch Tensors.
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor]
# Type used for returned data: always a numpy array.
OUTPUT_DATA_ARRAY_TYPE = np.ndarray
# Data that may be either a numpy array or a pandas DataFrame.
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
class Dataset(metaclass=ABCMeta):
@ -124,28 +127,50 @@ class StoredDataset(Dataset):
class ArrayDataset(Dataset):
    """
    Dataset that is based on x and y arrays (e.g., numpy/pandas/list...).

    The supplied containers are converted to numpy arrays once, at
    construction time, so the accessors always return numpy arrays
    regardless of the input type.
    """

    def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
        """
        ArrayDataset constructor.

        :param x: collection of data samples
        :param y: collection of labels (optional)
        :param kwargs: dataset parameters (not used by this constructor)
        :raises ValueError: if x or y is of an unsupported type, or if x and y
                            have different lengths
        """
        self._x = self._to_numpy(x, 'x')
        self._y = None if y is None else self._to_numpy(y, 'y')

        # The length check only makes sense when labels were supplied.
        if self._y is not None and len(self._x) != len(self._y):
            raise ValueError('Non equivalent lengths of x and y')

    @staticmethod
    def _to_numpy(data, name: str) -> OUTPUT_DATA_ARRAY_TYPE:
        """
        Convert a supported input container to a numpy array.

        :param data: numpy array, pandas DataFrame, list or torch Tensor
        :param name: name of the argument being converted, used in the error message
        :return: the data as a numpy array (numpy input is returned as is, without copying)
        :raises ValueError: if data is of an unsupported type
        """
        # isinstance (rather than an exact type comparison) also accepts subclasses.
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, pd.DataFrame):
            return data.to_numpy()
        if isinstance(data, list):
            return np.array(data)
        if isinstance(data, Tensor):
            # Tensor.numpy() fails for tensors that require grad or that live
            # on a non-CPU device, so detach and move to host memory first.
            return data.detach().cpu().numpy()
        raise ValueError('Non supported type for {}: {}'.format(name, type(data).__name__))

    def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
        """Return data samples as numpy array"""
        return self._x

    def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE:
        """Return labels as numpy array (None if no labels were supplied)"""
        return self._y
class DatasetFactory:
@ -189,7 +214,6 @@ class DatasetFactory:
class Data:
def __init__(self, train: Dataset = None, test: Dataset = None, **kwargs):
"""
Data class constructor.