mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-23 15:48:06 +02:00
Flake code cleanups
Signed-off-by: Maya Anderson <mayaa@il.ibm.com>
This commit is contained in:
parent
ad65f6f993
commit
0ee0bf05d6
4 changed files with 33 additions and 28 deletions
|
|
@ -13,8 +13,7 @@ from apt.risk.data_assessment.dataset_assessment_manager import DatasetAssessmen
|
|||
from apt.utils.dataset_utils import get_iris_dataset_np, get_diabetes_dataset_np, get_adult_dataset_pd, \
|
||||
get_nursery_dataset_pd
|
||||
from apt.utils.datasets import ArrayDataset
|
||||
from data_assessment.dataset_attack_membership_classification import DatasetAttackConfigMembershipClassification, \
|
||||
DatasetAttackMembershipClassification, DatasetAttackScoreMembershipClassification
|
||||
from data_assessment.dataset_attack_membership_classification import DatasetAttackScoreMembershipClassification
|
||||
from data_assessment.dataset_attack_membership_knn_probabilities import DatasetAttackScoreMembershipKnnProbabilities
|
||||
from data_assessment.dataset_attack_whole_dataset_knn_distance import DatasetAttackScoreWholeDatasetKnnDistance
|
||||
|
||||
|
|
@ -22,7 +21,6 @@ MIN_SHARE = 0.5
|
|||
MIN_ROC_AUC = 0.0
|
||||
MIN_PRECISION = 0.0
|
||||
|
||||
NUM_SYNTH_SAMPLES = 100
|
||||
NUM_SYNTH_COMPONENTS = 4
|
||||
|
||||
iris_dataset_np = get_iris_dataset_np()
|
||||
|
|
@ -99,13 +97,13 @@ def test_risk_kde(name, data, dataset_type, mgr):
|
|||
else:
|
||||
raise ValueError('Pandas dataset missing a preprocessing step')
|
||||
|
||||
num_synth_samples = x_train.shape[0] # required by the chi test
|
||||
synth_data = ArrayDataset(
|
||||
kde(x_train.shape[0], n_components=num_synth_components, original_data=encoded))
|
||||
# kde(NUM_SYNTH_SAMPLES, n_components=num_synth_components, original_data=encoded))
|
||||
kde(num_synth_samples, n_components=num_synth_components, original_data=encoded))
|
||||
original_data_members = ArrayDataset(encoded, y_train)
|
||||
original_data_non_members = ArrayDataset(encoded_test, y_test)
|
||||
|
||||
dataset_name = 'kde' + str(NUM_SYNTH_SAMPLES) + name
|
||||
dataset_name = 'kde' + str(num_synth_samples) + name
|
||||
assess_privacy_and_validate_result(mgr, original_data_members, original_data_non_members, synth_data, dataset_name,
|
||||
categorical_features)
|
||||
|
||||
|
|
@ -185,7 +183,6 @@ def filter_categorical(feature_names, return_feature_names: bool = True):
|
|||
return list(np.flatnonzero(np.char.startswith(feature_name_strs, 'cat__')))
|
||||
|
||||
|
||||
|
||||
def assess_privacy_and_validate_result(dataset_assessment_manager, original_data_members, original_data_non_members,
|
||||
synth_data, dataset_name, categorical_features):
|
||||
attack_scores = dataset_assessment_manager.assess(original_data_members, original_data_non_members, synth_data,
|
||||
|
|
@ -194,12 +191,12 @@ def assess_privacy_and_validate_result(dataset_assessment_manager, original_data
|
|||
for i, (assessment_type, scores) in enumerate(attack_scores.items()):
|
||||
if assessment_type == 'MembershipKnnProbabilities':
|
||||
score_g: DatasetAttackScoreMembershipKnnProbabilities = scores[0]
|
||||
assert(score_g.roc_auc_score > MIN_ROC_AUC)
|
||||
assert(score_g.average_precision_score > MIN_PRECISION)
|
||||
assert score_g.roc_auc_score > MIN_ROC_AUC
|
||||
assert score_g.average_precision_score > MIN_PRECISION
|
||||
elif assessment_type == 'WholeDatasetKnnDistance':
|
||||
score_h: DatasetAttackScoreWholeDatasetKnnDistance = scores[0]
|
||||
assert(score_h.share > MIN_SHARE)
|
||||
assert score_h.share > MIN_SHARE
|
||||
if assessment_type == 'MembershipClassification':
|
||||
score_mc: DatasetAttackScoreMembershipClassification = scores[0]
|
||||
assert(score_mc.synthetic_data_quality_warning is False)
|
||||
assert (0 <= score_mc.normalized_ratio <= 1)
|
||||
assert score_mc.synthetic_data_quality_warning is False
|
||||
assert 0 <= score_mc.normalized_ratio <= 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue