This commit is contained in:
olasaadi 2022-03-23 17:54:37 +02:00
parent 312469212e
commit 06158c8508
2 changed files with 12 additions and 8 deletions

View file

@ -41,13 +41,14 @@ def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
raise ValueError('Non supported type: ', type(arr).__name__)
def array2torch_tensor(arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
def array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
"""
converts from INPUT_DATA_ARRAY_TYPE to torch tensor array
"""
if type(arr) == np.ndarray:
return torch.from_numpy(arr)
if type(arr) == pd.DataFrame:
if type(arr) == pd.DataFrame or type(arr) == pd.Series:
self.is_pandas = True
return torch.from_numpy(arr.to_numpy())
if isinstance(arr, list):
return torch.tensor(arr)
@ -198,8 +199,11 @@ class PytorchData(Dataset):
:param y: collection of labels (optional)
:param kwargs: dataset parameters
"""
self._x = array2torch_tensor(x)
self._y = array2torch_tensor(y) if y is not None else None
self.is_pandas = False
self._y = array2torch_tensor(self, y) if y is not None else None
self._x = array2torch_tensor(self, x)
if self.is_pandas:
self.features_names = x.columns
if y is not None and len(self._x) != len(self._y):
raise ValueError('Non equivalent lengths of x and y')