formatting

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2023-08-10 13:20:28 +03:00
parent f85fc87bdd
commit 69e45d99e5
4 changed files with 17 additions and 16 deletions

View file

@ -20,7 +20,7 @@ from scipy.sparse import csr_matrix
logger = logging.getLogger(__name__)
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor]
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor, csr_matrix]
OUTPUT_DATA_ARRAY_TYPE = np.ndarray
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
@ -30,15 +30,15 @@ def array2numpy(arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
"""
converts from INPUT_DATA_ARRAY_TYPE to numpy array
"""
if type(arr) == np.ndarray:
if isinstance(arr, np.ndarray):
return arr
if type(arr) == pd.DataFrame or type(arr) == pd.Series:
if isinstance(arr, pd.DataFrame) or isinstance(arr, pd.Series):
return arr.to_numpy()
if isinstance(arr, list):
return np.array(arr)
if type(arr) == Tensor:
if isinstance(arr, Tensor):
return arr.detach().cpu().numpy()
if type(arr) == csr_matrix:
if isinstance(arr, csr_matrix):
return arr.toarray()
raise ValueError("Non supported type: ", type(arr).__name__)
@ -48,15 +48,15 @@ def array2torch_tensor(arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
"""
converts from INPUT_DATA_ARRAY_TYPE to torch tensor array
"""
if type(arr) == np.ndarray:
if isinstance(arr, np.ndarray):
return torch.from_numpy(arr)
if type(arr) == pd.DataFrame or type(arr) == pd.Series:
if isinstance(arr, pd.DataFrame) or isinstance(arr, pd.Series):
return torch.from_numpy(arr.to_numpy())
if isinstance(arr, list):
return torch.tensor(arr)
if type(arr) == Tensor:
if isinstance(arr, Tensor):
return arr
if type(arr) == csr_matrix:
if isinstance(arr, csr_matrix):
return torch.from_numpy(arr.toarray())
raise ValueError("Non supported type: ", type(arr).__name__)
@ -222,7 +222,7 @@ class ArrayDataset(Dataset):
def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None,
features_names: Optional[list] = None, **kwargs):
self.is_pandas = self.is_pandas = type(x) == pd.DataFrame or type(x) == pd.Series
self.is_pandas = self.is_pandas = isinstance(x, pd.DataFrame) or isinstance(x, pd.Series)
self.features_names = features_names
self._y = array2numpy(y) if y is not None else None
@ -330,7 +330,7 @@ class PytorchData(Dataset):
self._y = array2torch_tensor(y) if y is not None else None
self._x = array2torch_tensor(x)
self.is_pandas = type(x) == pd.DataFrame or type(x) == pd.Series
self.is_pandas = isinstance(x, pd.DataFrame) or isinstance(x, pd.Series)
if self.is_pandas:
self.features_names = x.columns