Train just on qi (#15)

* QI updates
* update code to support training ML on QI features
* fix code so features that are not from QI should not be part of generalizations
and add description
* merging two branches, training on QI and on all data
* adding tests and asserts
This commit is contained in:
olasaadi 2022-01-12 17:01:27 +02:00 committed by GitHub
parent 2eb626c00c
commit a9a93c8a3a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 373 additions and 135 deletions

View file

@ -6,15 +6,15 @@ from os import path, mkdir
from six.moves.urllib.request import urlretrieve
def _load_iris(test_set_size: float=0.3):
def _load_iris(test_set_size: float = 0.3):
iris = datasets.load_iris()
data = iris.data
labels = iris.target
# Split training and test sets
x_train, x_test, y_train, y_test = model_selection.train_test_split(data, labels, test_size=test_set_size,
random_state=18, stratify=labels,
shuffle=True)
random_state=18, stratify=labels,
shuffle=True)
return (x_train, y_train), (x_test, y_test)
@ -29,6 +29,77 @@ def get_iris_dataset():
return _load_iris()
def get_german_credit_dataset(test_set: float = 0.3):
"""
Loads the UCI German_credit dataset from `tests/datasets/german` or downloads it if necessary.
:return: Dataset and labels as pandas dataframes.
"""
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/statlog/german/german.data'
data_dir = '../datasets/german'
data_file = '../datasets/german/data'
if not path.exists(data_dir):
mkdir(data_dir)
ssl._create_default_https_context = ssl._create_unverified_context
if not path.exists(data_file):
urlretrieve(url, data_file)
# load data
features = ["Existing_checking_account", "Duration_in_month", "Credit_history", "Purpose", "Credit_amount",
"Savings_account", "Present_employment_since", "Installment_rate", "Personal_status_sex", "debtors",
"Present_residence", "Property", "Age", "Other_installment_plans", "Housing",
"Number_of_existing_credits", "Job", "N_people_being_liable_provide_maintenance", "Telephone",
"Foreign_worker", "label"]
data = pd.read_csv(data_file, sep=" ", names=features, engine="python")
# remove rows with missing label
data = data.dropna(subset=["label"])
_modify_german_dataset(data)
# Split training and test sets
stratified = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=test_set, random_state=18)
for train_set, test_set in stratified.split(data, data["label"]):
train = data.iloc[train_set]
test = data.iloc[test_set]
x_train = train.drop(["label"], axis=1)
y_train = train.loc[:, "label"]
x_test = test.drop(["label"], axis=1)
y_test = test.loc[:, "label"]
categorical_features = ["Existing_checking_account", "Credit_history", "Purpose", "Savings_account",
"Present_employment_since", "Personal_status_sex", "debtors", "Property",
"Other_installment_plans", "Housing", "Job"]
x_train.reset_index(drop=True, inplace=True)
y_train.reset_index(drop=True, inplace=True)
x_test.reset_index(drop=True, inplace=True)
y_test.reset_index(drop=True, inplace=True)
return (x_train, y_train), (x_test, y_test)
def _modify_german_dataset(data):
def modify_Foreign_worker(value):
if value == 'A201':
return 1
elif value == 'A202':
return 0
else:
raise Exception('Bad value')
def modify_Telephone(value):
if value == 'A191':
return 0
elif value == 'A192':
return 1
else:
raise Exception('Bad value')
data['Foreign_worker'] = data['Foreign_worker'].apply(modify_Foreign_worker)
data['Telephone'] = data['Telephone'].apply(modify_Telephone)
def get_adult_dataset():
"""
Loads the UCI Adult dataset from `tests/datasets/adult` or downloads it if necessary.