diff --git a/apt/utils/datasets/datasets.py b/apt/utils/datasets/datasets.py
index 26ffdef..72efd5a 100644
--- a/apt/utils/datasets/datasets.py
+++ b/apt/utils/datasets/datasets.py
@@ -24,7 +24,7 @@ OUTPUT_DATA_ARRAY_TYPE = np.ndarray
 DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
 
 
-def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
+def array2numpy(arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
     """
     converts from INPUT_DATA_ARRAY_TYPE to numpy array
 
@@ -32,7 +32,6 @@ def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
     if type(arr) == np.ndarray:
         return arr
     if type(arr) == pd.DataFrame or type(arr) == pd.Series:
-        self.is_pandas = True
         return arr.to_numpy()
     if isinstance(arr, list):
         return np.array(arr)
@@ -42,14 +41,13 @@ def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
     raise ValueError("Non supported type: ", type(arr).__name__)
 
 
-def array2torch_tensor(self, arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
+def array2torch_tensor(arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
     """
     converts from INPUT_DATA_ARRAY_TYPE to torch tensor array
     """
     if type(arr) == np.ndarray:
         return torch.from_numpy(arr)
     if type(arr) == pd.DataFrame or type(arr) == pd.Series:
-        self.is_pandas = True
         return torch.from_numpy(arr.to_numpy())
     if isinstance(arr, list):
         return torch.tensor(arr)
@@ -178,10 +176,12 @@ class ArrayDataset(Dataset):
         :param feature_names: list of str, The feature names, in the order that they appear in the data (optional)
         :param kwargs: dataset parameters
         """
-        self.is_pandas = False
+        self.is_pandas = type(x) == pd.DataFrame or type(x) == pd.Series
+
         self.features_names = features_names
-        self._y = array2numpy(self, y) if y is not None else None
-        self._x = array2numpy(self, x)
+        self._y = array2numpy(y) if y is not None else None
+        self._x = array2numpy(x)
+
         if self.is_pandas:
             if features_names and not np.array_equal(features_names, x.columns):
                 raise ValueError("The supplied features are not the same as in the data features")
@@ -207,9 +207,11 @@ class PytorchData(Dataset):
         :param y: collection of labels (optional)
         :param kwargs: dataset parameters
         """
-        self.is_pandas = False
-        self._y = array2torch_tensor(self, y) if y is not None else None
-        self._x = array2torch_tensor(self, x)
+        self._y = array2torch_tensor(y) if y is not None else None
+        self._x = array2torch_tensor(x)
+
+        self.is_pandas = type(x) == pd.DataFrame or type(x) == pd.Series
+
         if self.is_pandas:
             self.features_names = x.columns
 
@@ -223,11 +225,11 @@ class PytorchData(Dataset):
 
     def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
         """Return data samples as numpy array"""
-        return array2numpy(self, self._x)
+        return array2numpy(self._x)
 
     def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE:
         """Return labels as numpy array"""
-        return array2numpy(self, self._y) if self._y is not None else None
+        return array2numpy(self._y) if self._y is not None else None
 
     def get_sample_item(self, idx) -> Tensor:
         return self._x[idx]