Externalize common test code to methods.

Support for sparse matrix.

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2023-08-10 12:57:59 +03:00
parent 3de93a87f1
commit f85fc87bdd
3 changed files with 122 additions and 323 deletions

View file

@ -15,6 +15,7 @@ import pandas as pd
import logging
import torch
from torch import Tensor
from scipy.sparse import csr_matrix
logger = logging.getLogger(__name__)
@ -37,6 +38,8 @@ def array2numpy(arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
return np.array(arr)
if type(arr) == Tensor:
return arr.detach().cpu().numpy()
if type(arr) == csr_matrix:
return arr.toarray()
raise ValueError("Non supported type: ", type(arr).__name__)
@ -53,6 +56,8 @@ def array2torch_tensor(arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
return torch.tensor(arr)
if type(arr) == Tensor:
return arr
if type(arr) == csr_matrix:
return torch.from_numpy(arr.toarray())
raise ValueError("Non supported type: ", type(arr).__name__)