mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-29 15:59:38 +02:00
Externalize common test code to methods.
Support for sparse matrix. Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
3de93a87f1
commit
f85fc87bdd
3 changed files with 122 additions and 323 deletions
|
|
@ -15,6 +15,7 @@ import pandas as pd
|
|||
import logging
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -37,6 +38,8 @@ def array2numpy(arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
|
|||
return np.array(arr)
|
||||
if type(arr) == Tensor:
|
||||
return arr.detach().cpu().numpy()
|
||||
if type(arr) == csr_matrix:
|
||||
return arr.toarray()
|
||||
|
||||
raise ValueError("Non supported type: ", type(arr).__name__)
|
||||
|
||||
|
|
@ -53,6 +56,8 @@ def array2torch_tensor(arr: INPUT_DATA_ARRAY_TYPE) -> Tensor:
|
|||
return torch.tensor(arr)
|
||||
if type(arr) == Tensor:
|
||||
return arr
|
||||
if type(arr) == csr_matrix:
|
||||
return torch.from_numpy(arr.toarray())
|
||||
|
||||
raise ValueError("Non supported type: ", type(arr).__name__)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue