Add more to wrappers

This commit is contained in:
abigailt 2022-03-15 11:42:57 +02:00
parent f2df2fcc8c
commit 45cc9180b8
6 changed files with 74 additions and 30 deletions

View file

@ -5,7 +5,7 @@ Implementation of utility classes for dataset handling
"""
from abc import ABCMeta, abstractmethod
from typing import Callable, Collection, Any, Union
from typing import Callable, Collection, Any, Union, List, Optional
import tarfile
import os
@ -13,11 +13,14 @@ import urllib.request
import numpy as np
import pandas as pd
import logging
from torch import Tensor
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Array-like data held as either a numpy array or a pandas DataFrame.
# NOTE(review): same union as DATA_PANDAS_NUMPY_TYPE below - confirm whether both aliases are still needed.
DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame]
# Types accepted as input by ArrayDataset's constructor (numpy array, pandas DataFrame, list or torch Tensor).
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor]
# Type returned by ArrayDataset.get_samples() / get_labels(): always a numpy array.
OUTPUT_DATA_ARRAY_TYPE = np.ndarray
# Data that is specifically a numpy array or a pandas DataFrame.
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
class Dataset(metaclass=ABCMeta):
@ -124,28 +127,50 @@ class StoredDataset(Dataset):
class ArrayDataset(Dataset):
    """Dataset that is based on x and y arrays (e.g., numpy/pandas/list...)"""

    def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
        """
        ArrayDataset constructor.

        Both inputs are converted to numpy arrays once, at construction time,
        so the getters always return ``np.ndarray`` regardless of input type.

        :param x: collection of data samples
        :param y: collection of labels (optional)
        :param kwargs: dataset parameters
        :raises ValueError: if x or y is of a non supported type, or if x and y
                            have different lengths
        """
        self._x = self._to_numpy(x, 'x')
        self._y = None if y is None else self._to_numpy(y, 'y')

        if self._y is not None and len(self._x) != len(self._y):
            raise ValueError('Non equivalent lengths of x and y')

    @staticmethod
    def _to_numpy(data, name: str) -> np.ndarray:
        """
        Convert a supported array-like object to a numpy array.

        :param data: numpy array, pandas DataFrame, list or torch Tensor
        :param name: name of the argument being converted, used in the error message
        :return: the data as a numpy array
        :raises ValueError: if data is of a non supported type
        """
        # isinstance (rather than an exact type comparison) so that subclasses
        # of the supported types are accepted as well.
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, pd.DataFrame):
            return data.to_numpy()
        if isinstance(data, list):
            return np.array(data)
        if isinstance(data, Tensor):
            # Tensor.numpy() raises for tensors that require grad or that live
            # on a non-CPU device; detach and move to CPU first.
            return data.detach().cpu().numpy()
        raise ValueError(f'Non supported type for {name}: {type(data).__name__}')

    def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
        """Return data samples as numpy array"""
        return self._x

    def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE:
        """Return labels as numpy array (None if no labels were supplied)"""
        return self._y
class DatasetFactory:
@ -189,7 +214,6 @@ class DatasetFactory:
class Data:
def __init__(self, train: Dataset = None, test: Dataset = None, **kwargs):
"""
Data class constructor.