using dataset wrapper on anonymizer

This commit is contained in:
olasaadi 2022-03-01 02:28:41 +02:00
parent a432b8f5f9
commit 7b788b9018
2 changed files with 42 additions and 59 deletions

View file

@ -5,8 +5,6 @@ import ssl
from os import path, mkdir
from six.moves.urllib.request import urlretrieve
from apt.utils.datasets import ArrayDataset, Data
def _load_iris(test_set_size: float = 0.3):
iris = datasets.load_iris()
@ -16,10 +14,8 @@ def _load_iris(test_set_size: float = 0.3):
# Split training and test sets
x_train, x_test, y_train, y_test = model_selection.train_test_split(data, labels, test_size=test_set_size,
random_state=18, stratify=labels)
train_dataset = ArrayDataset(x_train, y_train)
test_dataset = ArrayDataset(x_test, y_test)
dataset = Data(train_dataset, test_dataset)
return dataset
return (x_train, y_train), (x_test, y_test)
def get_iris_dataset():
@ -41,10 +37,7 @@ def _load_diabetes(test_set_size: float = 0.3):
x_train, x_test, y_train, y_test = model_selection.train_test_split(data, labels, test_size=test_set_size,
random_state=18)
train_dataset = ArrayDataset(x_train, y_train)
test_dataset = ArrayDataset(x_test, y_test)
dataset = Data(train_dataset, test_dataset)
return dataset
return (x_train, y_train), (x_test, y_test)
def get_diabetes_dataset():
@ -104,10 +97,7 @@ def get_german_credit_dataset(test_set: float = 0.3):
x_test.reset_index(drop=True, inplace=True)
y_test.reset_index(drop=True, inplace=True)
train_dataset = ArrayDataset(x_train, y_train)
test_dataset = ArrayDataset(x_test, y_test)
dataset = Data(train_dataset, test_dataset)
return dataset
return (x_train, y_train), (x_test, y_test)
def _modify_german_dataset(data):
@ -166,10 +156,8 @@ def get_adult_dataset():
y_train = train.loc[:, 'label']
x_test = test.drop(['label'], axis=1)
y_test = test.loc[:, 'label']
train_dataset = ArrayDataset(x_train, y_train)
test_dataset = ArrayDataset(x_test, y_test)
dataset = Data(train_dataset, test_dataset)
return dataset
return (x_train, y_train), (x_test, y_test)
def _modify_adult_dataset(data):
@ -327,10 +315,5 @@ def get_nursery_dataset(raw: bool = True, test_set: float = 0.2, transform_socia
y_train = train.loc[:, "label"]
x_test = test.drop(["label"], axis=1)
y_test = test.loc[:, "label"]
x_train = x_train.astype(str)
x_test = x_test.astype(str)
train_dataset = ArrayDataset(x_train, y_train)
test_dataset = ArrayDataset(x_test, y_test)
dataset = Data(train_dataset, test_dataset)
return dataset
return (x_train, y_train), (x_test, y_test)