remove self from array2numpy and array2torch_tensor functions

This commit is contained in:
Ron Shmelkin 2022-07-24 15:32:09 +03:00
parent c77e34e373
commit 15d7008224
No known key found for this signature in database
GPG key ID: A4289A6607B5C294

View file

@ -24,7 +24,7 @@ OUTPUT_DATA_ARRAY_TYPE = np.ndarray
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
def array2numpy(arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
"""
converts from INPUT_DATA_ARRAY_TYPE to numpy array
@ -32,7 +32,6 @@ def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
if type(arr) == np.ndarray:
return arr
if type(arr) == pd.DataFrame or type(arr) == pd.Series:
self.is_pandas = True
return arr.to_numpy()
if isinstance(arr, list):
return np.array(arr)
@ -42,14 +41,13 @@ def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
raise ValueError("Non supported type: ", type(arr).__name__)
def array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
def array2torch_tensor(arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
"""
converts from INPUT_DATA_ARRAY_TYPE to torch tensor array
"""
if type(arr) == np.ndarray:
return torch.from_numpy(arr)
if type(arr) == pd.DataFrame or type(arr) == pd.Series:
self.is_pandas = True
return torch.from_numpy(arr.to_numpy())
if isinstance(arr, list):
return torch.tensor(arr)
@ -178,10 +176,12 @@ class ArrayDataset(Dataset):
:param feature_names: list of str, The feature names, in the order that they appear in the data (optional)
:param kwargs: dataset parameters
"""
self.is_pandas = False
self.is_pandas = self.is_pandas = type(x) == pd.DataFrame or type(x) == pd.Series
self.features_names = features_names
self._y = array2numpy(self, y) if y is not None else None
self._x = array2numpy(self, x)
self._y = array2numpy(y) if y is not None else None
self._x = array2numpy(x)
if self.is_pandas:
if features_names and not np.array_equal(features_names, x.columns):
raise ValueError("The supplied features are not the same as in the data features")
@ -207,9 +207,11 @@ class PytorchData(Dataset):
:param y: collection of labels (optional)
:param kwargs: dataset parameters
"""
self.is_pandas = False
self._y = array2torch_tensor(self, y) if y is not None else None
self._x = array2torch_tensor(self, x)
self._y = array2torch_tensor(y) if y is not None else None
self._x = array2torch_tensor(x)
self.is_pandas = type(x) == pd.DataFrame or type(x) == pd.Series
if self.is_pandas:
self.features_names = x.columns
@ -223,11 +225,11 @@ class PytorchData(Dataset):
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""Return data samples as numpy array"""
return array2numpy(self, self._x)
return array2numpy(self._x)
def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""Return labels as numpy array"""
return array2numpy(self, self._y) if self._y is not None else None
return array2numpy(self._y) if self._y is not None else None
def get_sample_item(self, idx) -> Tensor:
return self._x[idx]