Flake code cleanups

Signed-off-by: Maya Anderson <mayaa@il.ibm.com>
This commit is contained in:
Maya Anderson 2023-09-20 09:23:22 +03:00
parent ad65f6f993
commit 0ee0bf05d6
4 changed files with 33 additions and 28 deletions

View file

@ -13,8 +13,7 @@ from apt.risk.data_assessment.dataset_assessment_manager import DatasetAssessmen
from apt.utils.dataset_utils import get_iris_dataset_np, get_diabetes_dataset_np, get_adult_dataset_pd, \
get_nursery_dataset_pd
from apt.utils.datasets import ArrayDataset
from data_assessment.dataset_attack_membership_classification import DatasetAttackConfigMembershipClassification, \
DatasetAttackMembershipClassification, DatasetAttackScoreMembershipClassification
from data_assessment.dataset_attack_membership_classification import DatasetAttackScoreMembershipClassification
from data_assessment.dataset_attack_membership_knn_probabilities import DatasetAttackScoreMembershipKnnProbabilities
from data_assessment.dataset_attack_whole_dataset_knn_distance import DatasetAttackScoreWholeDatasetKnnDistance
@ -22,7 +21,6 @@ MIN_SHARE = 0.5
MIN_ROC_AUC = 0.0
MIN_PRECISION = 0.0
NUM_SYNTH_SAMPLES = 100
NUM_SYNTH_COMPONENTS = 4
iris_dataset_np = get_iris_dataset_np()
@ -99,13 +97,13 @@ def test_risk_kde(name, data, dataset_type, mgr):
else:
raise ValueError('Pandas dataset missing a preprocessing step')
num_synth_samples = x_train.shape[0] # required by the chi test
synth_data = ArrayDataset(
kde(x_train.shape[0], n_components=num_synth_components, original_data=encoded))
# kde(NUM_SYNTH_SAMPLES, n_components=num_synth_components, original_data=encoded))
kde(num_synth_samples, n_components=num_synth_components, original_data=encoded))
original_data_members = ArrayDataset(encoded, y_train)
original_data_non_members = ArrayDataset(encoded_test, y_test)
dataset_name = 'kde' + str(NUM_SYNTH_SAMPLES) + name
dataset_name = 'kde' + str(num_synth_samples) + name
assess_privacy_and_validate_result(mgr, original_data_members, original_data_non_members, synth_data, dataset_name,
categorical_features)
@ -185,7 +183,6 @@ def filter_categorical(feature_names, return_feature_names: bool = True):
return list(np.flatnonzero(np.char.startswith(feature_name_strs, 'cat__')))
def assess_privacy_and_validate_result(dataset_assessment_manager, original_data_members, original_data_non_members,
synth_data, dataset_name, categorical_features):
attack_scores = dataset_assessment_manager.assess(original_data_members, original_data_non_members, synth_data,
@ -194,12 +191,12 @@ def assess_privacy_and_validate_result(dataset_assessment_manager, original_data
for i, (assessment_type, scores) in enumerate(attack_scores.items()):
if assessment_type == 'MembershipKnnProbabilities':
score_g: DatasetAttackScoreMembershipKnnProbabilities = scores[0]
assert(score_g.roc_auc_score > MIN_ROC_AUC)
assert(score_g.average_precision_score > MIN_PRECISION)
assert score_g.roc_auc_score > MIN_ROC_AUC
assert score_g.average_precision_score > MIN_PRECISION
elif assessment_type == 'WholeDatasetKnnDistance':
score_h: DatasetAttackScoreWholeDatasetKnnDistance = scores[0]
assert(score_h.share > MIN_SHARE)
assert score_h.share > MIN_SHARE
if assessment_type == 'MembershipClassification':
score_mc: DatasetAttackScoreMembershipClassification = scores[0]
assert(score_mc.synthetic_data_quality_warning is False)
assert (0 <= score_mc.normalized_ratio <= 1)
assert score_mc.synthetic_data_quality_warning is False
assert 0 <= score_mc.normalized_ratio <= 1