mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-28 22:36:22 +02:00
Train just on qi (#15)
* QI updates * update code to support training ML models on QI features only * fix code so that features not in the QI are excluded from generalizations, and add a description * merge the two branches (training on QI vs. on all data) * add tests and asserts
This commit is contained in:
parent
2eb626c00c
commit
a9a93c8a3a
4 changed files with 373 additions and 135 deletions
77
apt/utils.py
77
apt/utils.py
|
|
@ -6,15 +6,15 @@ from os import path, mkdir
|
|||
from six.moves.urllib.request import urlretrieve
|
||||
|
||||
|
||||
def _load_iris(test_set_size: float = 0.3):
    """
    Load the scikit-learn iris dataset and split it into train and test sets.

    :param test_set_size: Fraction of the samples to place in the test set.
    :return: ((x_train, y_train), (x_test, y_test)) as numpy arrays.
    """
    bunch = datasets.load_iris()

    # Stratified split with a fixed seed so repeated loads are reproducible.
    split = model_selection.train_test_split(bunch.data, bunch.target,
                                             test_size=test_set_size,
                                             random_state=18,
                                             stratify=bunch.target,
                                             shuffle=True)
    x_train, x_test, y_train, y_test = split

    return (x_train, y_train), (x_test, y_test)
|
|
@ -29,6 +29,77 @@ def get_iris_dataset():
|
|||
return _load_iris()
|
||||
|
||||
|
||||
def get_german_credit_dataset(test_set: float = 0.3):
    """
    Loads the UCI German_credit dataset from `tests/datasets/german` or downloads it if necessary.

    :param test_set: Fraction of the samples to place in the test set.
    :return: Dataset and labels as pandas dataframes, as ((x_train, y_train), (x_test, y_test)).
    """
    url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/statlog/german/german.data'
    data_dir = '../datasets/german'
    data_file = '../datasets/german/data'

    if not path.exists(data_dir):
        mkdir(data_dir)

    # NOTE(review): disabling certificate verification is a security risk; it is kept
    # only so the UCI download works where the certificate chain is not trusted —
    # confirm this is still required.
    ssl._create_default_https_context = ssl._create_unverified_context
    if not path.exists(data_file):
        urlretrieve(url, data_file)

    # load data; the raw file is space-separated with no header row
    features = ["Existing_checking_account", "Duration_in_month", "Credit_history", "Purpose", "Credit_amount",
                "Savings_account", "Present_employment_since", "Installment_rate", "Personal_status_sex", "debtors",
                "Present_residence", "Property", "Age", "Other_installment_plans", "Housing",
                "Number_of_existing_credits", "Job", "N_people_being_liable_provide_maintenance", "Telephone",
                "Foreign_worker", "label"]
    data = pd.read_csv(data_file, sep=" ", names=features, engine="python")
    # remove rows with missing label
    data = data.dropna(subset=["label"])
    # recode the Telephone / Foreign_worker categorical codes to 0/1 ints, in place
    _modify_german_dataset(data)

    # Split training and test sets with a single stratified shuffle split.
    # Use distinct names for the index arrays so the `test_set` parameter (a float)
    # is not shadowed by the split output, as it was previously.
    stratified = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=test_set, random_state=18)
    train_idx, test_idx = next(stratified.split(data, data["label"]))
    train = data.iloc[train_idx]
    test = data.iloc[test_idx]
    x_train = train.drop(["label"], axis=1)
    y_train = train.loc[:, "label"]
    x_test = test.drop(["label"], axis=1)
    y_test = test.loc[:, "label"]

    # Reset to clean 0..n-1 indexes so positional and label indexing agree downstream.
    x_train.reset_index(drop=True, inplace=True)
    y_train.reset_index(drop=True, inplace=True)
    x_test.reset_index(drop=True, inplace=True)
    y_test.reset_index(drop=True, inplace=True)

    return (x_train, y_train), (x_test, y_test)
|
||||
|
||||
|
||||
def _modify_german_dataset(data):
|
||||
|
||||
def modify_Foreign_worker(value):
|
||||
if value == 'A201':
|
||||
return 1
|
||||
elif value == 'A202':
|
||||
return 0
|
||||
else:
|
||||
raise Exception('Bad value')
|
||||
|
||||
def modify_Telephone(value):
|
||||
if value == 'A191':
|
||||
return 0
|
||||
elif value == 'A192':
|
||||
return 1
|
||||
else:
|
||||
raise Exception('Bad value')
|
||||
data['Foreign_worker'] = data['Foreign_worker'].apply(modify_Foreign_worker)
|
||||
data['Telephone'] = data['Telephone'].apply(modify_Telephone)
|
||||
|
||||
|
||||
def get_adult_dataset():
|
||||
"""
|
||||
Loads the UCI Adult dataset from `tests/datasets/adult` or downloads it if necessary.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue