mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-25 04:46:21 +02:00
Merge branch 'wrappers' into dataset_wrapper_anonimizer
This commit is contained in:
commit
5f6a258f8f
2 changed files with 6 additions and 2 deletions
|
|
@ -18,12 +18,14 @@ from torch import Tensor
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor]
|
||||
OUTPUT_DATA_ARRAY_TYPE = np.ndarray
|
||||
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
|
||||
|
||||
|
||||
def array2numpy(self, arr: INPUT_DATA_ARRAY_TYPE) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
|
||||
"""
|
||||
converts from INPUT_DATA_ARRAY_TYPE to numpy array
|
||||
"""
|
||||
|
|
@ -210,11 +212,13 @@ class PytorchData(Dataset):
|
|||
if y is not None and len(self._x) != len(self._y):
|
||||
raise ValueError('Non equivalent lengths of x and y')
|
||||
|
||||
|
||||
if self._y is not None:
|
||||
self.__getitem__ = self.get_item
|
||||
else:
|
||||
self.__getitem__ = self.get_sample_item
|
||||
|
||||
|
||||
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""Return data samples as numpy array"""
|
||||
return array2numpy(self._x)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from apt.utils.models import SklearnClassifier, SklearnRegressor
|
||||
from apt.utils.models import SklearnClassifier, SklearnRegressor, ModelOutputType
|
||||
from apt.utils.datasets import ArrayDataset
|
||||
from apt.utils import dataset_utils
|
||||
|
||||
|
|
@ -11,7 +11,7 @@ from sklearn.ensemble import RandomForestClassifier
|
|||
def test_sklearn_classifier():
|
||||
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset()
|
||||
underlying_model = RandomForestClassifier()
|
||||
model = SklearnClassifier(underlying_model)
|
||||
model = SklearnClassifier(underlying_model, ModelOutputType.CLASSIFIER_VECTOR)
|
||||
train = ArrayDataset(x_train, y_train)
|
||||
test = ArrayDataset(x_test, y_test)
|
||||
model.fit(train)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue