New model wrappers (#32)

* keras wrapper + blackbox classifier wrapper (fix #7)

* fix error in NCP calculation

* Update notebooks

* Fix #25 (incorrect attack_feature indexes for social feature in notebook)

* Consistent naming of internal parameters
This commit is contained in:
abigailgold 2022-05-12 15:44:29 +03:00 committed by GitHub
parent fd6be8e778
commit fe676fa426
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 1407 additions and 656 deletions

View file

@ -6,7 +6,7 @@ from os import path, mkdir
from six.moves.urllib.request import urlretrieve
def get_iris_dataset(test_set: float = 0.3):
def get_iris_dataset_np(test_set: float = 0.3):
"""
Loads the Iris dataset from scikit-learn.
@ -29,7 +29,7 @@ def _load_iris(test_set_size: float = 0.3):
return (x_train, y_train), (x_test, y_test)
def get_diabetes_dataset(test_set: float = 0.3):
def get_diabetes_dataset_np(test_set: float = 0.3):
"""
Loads the Diabetes dataset from scikit-learn.
@ -52,7 +52,7 @@ def _load_diabetes(test_set_size: float = 0.3):
return (x_train, y_train), (x_test, y_test)
def get_german_credit_dataset(test_set: float = 0.3):
def get_german_credit_dataset_pd(test_set: float = 0.3):
"""
Loads the UCI German credit dataset from `tests/datasets/german` or downloads it from
https://archive.ics.uci.edu/ml/machine-learning-databases/statlog/german/ if necessary.
@ -122,11 +122,16 @@ def _modify_german_dataset(data):
return 1
else:
raise Exception('Bad value')
def modify_label(value):
return value - 1
data['Foreign_worker'] = data['Foreign_worker'].apply(modify_Foreign_worker)
data['Telephone'] = data['Telephone'].apply(modify_Telephone)
data['label'] = data['label'].apply(modify_label)
def get_adult_dataset():
def get_adult_dataset_pd():
"""
Loads the UCI Adult dataset from `tests/datasets/adult` or downloads it from
https://archive.ics.uci.edu/ml/machine-learning-databases/adult/ if necessary.
@ -228,7 +233,7 @@ def _modify_adult_dataset(data):
return data.drop(['fnlwgt', 'education'], axis=1)
def get_nursery_dataset(raw: bool = True, test_set: float = 0.2, transform_social: bool = False):
def get_nursery_dataset_pd(raw: bool = True, test_set: float = 0.2, transform_social: bool = False):
"""
Loads the UCI Nursery dataset from `tests/datasets/nursery` or downloads it from
https://archive.ics.uci.edu/ml/machine-learning-databases/nursery/ if necessary.