Address review comments - make dataset_name optional, fix filename. Fix score serialization - don't serialize the result.

Signed-off-by: Maya Anderson <mayaa@il.ibm.com>
This commit is contained in:
Maya Anderson 2023-03-09 22:38:39 +02:00
parent 3ae64054f8
commit a122976807
6 changed files with 73 additions and 58 deletions

View file

@ -60,8 +60,8 @@ def test_risk_anonymization(name, data, dataset_type, k, mgr):
original_data_members = ArrayDataset(preprocessed_x_train, y_train)
original_data_non_members = ArrayDataset(preprocessed_x_test, y_test)
score_g, score_h = mgr.assess(original_data_members, original_data_non_members, anonymized_data,
f'anon_k{k}_{name}')
[score_g, score_h] = mgr.assess(original_data_members, original_data_non_members, anonymized_data,
f'anon_k{k}_{name}')
assert (score_g.roc_auc_score > 0.5)
assert (score_g.average_precision_score > 0.5)
@ -96,8 +96,8 @@ def test_risk_kde(name, data, dataset_type, mgr):
original_data_members = ArrayDataset(encoded, y_train)
original_data_non_members = ArrayDataset(encoded_test, y_test)
score_g, score_h = mgr.assess(original_data_members, original_data_non_members, synth_data,
'kde' + str(NUM_SYNTH_SAMPLES) + name)
[score_g, score_h] = mgr.assess(original_data_members, original_data_non_members, synth_data,
'kde' + str(NUM_SYNTH_SAMPLES) + name)
assert (score_g.roc_auc_score > 0.5)
assert (score_g.average_precision_score > 0.5)