mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-28 22:36:22 +02:00
apply changes after rebase with wrappers
This commit is contained in:
parent
6afb175d6f
commit
b4eddabd37
4 changed files with 32 additions and 36 deletions
|
|
@ -5,7 +5,7 @@ import ssl
|
|||
from os import path, mkdir
|
||||
from six.moves.urllib.request import urlretrieve
|
||||
|
||||
from apt.utils.datasets import BaseDataset, Data
|
||||
from apt.utils.datasets import ArrayDataset, Data
|
||||
|
||||
|
||||
def _load_iris(test_set_size: float = 0.3):
|
||||
|
|
@ -16,8 +16,8 @@ def _load_iris(test_set_size: float = 0.3):
|
|||
# Split training and test sets
|
||||
x_train, x_test, y_train, y_test = model_selection.train_test_split(data, labels, test_size=test_set_size,
|
||||
random_state=18, stratify=labels)
|
||||
train_dataset = BaseDataset(x_train, y_train)
|
||||
test_dataset = BaseDataset(x_test, y_test)
|
||||
train_dataset = ArrayDataset(x_train, y_train)
|
||||
test_dataset = ArrayDataset(x_test, y_test)
|
||||
dataset = Data(train_dataset, test_dataset)
|
||||
return dataset
|
||||
|
||||
|
|
@ -41,8 +41,8 @@ def _load_diabetes(test_set_size: float = 0.3):
|
|||
x_train, x_test, y_train, y_test = model_selection.train_test_split(data, labels, test_size=test_set_size,
|
||||
random_state=18)
|
||||
|
||||
train_dataset = BaseDataset(x_train, y_train)
|
||||
test_dataset = BaseDataset(x_test, y_test)
|
||||
train_dataset = ArrayDataset(x_train, y_train)
|
||||
test_dataset = ArrayDataset(x_test, y_test)
|
||||
dataset = Data(train_dataset, test_dataset)
|
||||
return dataset
|
||||
|
||||
|
|
@ -104,8 +104,8 @@ def get_german_credit_dataset(test_set: float = 0.3):
|
|||
x_test.reset_index(drop=True, inplace=True)
|
||||
y_test.reset_index(drop=True, inplace=True)
|
||||
|
||||
train_dataset = BaseDataset(x_train, y_train)
|
||||
test_dataset = BaseDataset(x_test, y_test)
|
||||
train_dataset = ArrayDataset(x_train, y_train)
|
||||
test_dataset = ArrayDataset(x_test, y_test)
|
||||
dataset = Data(train_dataset, test_dataset)
|
||||
return dataset
|
||||
|
||||
|
|
@ -166,8 +166,8 @@ def get_adult_dataset():
|
|||
y_train = train.loc[:, 'label']
|
||||
x_test = test.drop(['label'], axis=1)
|
||||
y_test = test.loc[:, 'label']
|
||||
train_dataset = BaseDataset(x_train, y_train)
|
||||
test_dataset = BaseDataset(x_test, y_test)
|
||||
train_dataset = ArrayDataset(x_train, y_train)
|
||||
test_dataset = ArrayDataset(x_test, y_test)
|
||||
dataset = Data(train_dataset, test_dataset)
|
||||
return dataset
|
||||
|
||||
|
|
@ -330,7 +330,7 @@ def get_nursery_dataset(raw: bool = True, test_set: float = 0.2, transform_socia
|
|||
x_train = x_train.astype(str)
|
||||
x_test = x_test.astype(str)
|
||||
|
||||
train_dataset = BaseDataset(x_train, y_train)
|
||||
test_dataset = BaseDataset(x_test, y_test)
|
||||
train_dataset = ArrayDataset(x_train, y_train)
|
||||
test_dataset = ArrayDataset(x_test, y_test)
|
||||
dataset = Data(train_dataset, test_dataset)
|
||||
return dataset
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue