diff --git a/apt/risk/data_assessment/dataset_assessment_manager.py b/apt/risk/data_assessment/dataset_assessment_manager.py index faeac69..fc2c879 100644 --- a/apt/risk/data_assessment/dataset_assessment_manager.py +++ b/apt/risk/data_assessment/dataset_assessment_manager.py @@ -1,15 +1,15 @@ from __future__ import annotations -from typing import Optional -from dataclasses import dataclass +from dataclasses import dataclass +from typing import Optional import pandas as pd +from apt.risk.data_assessment.dataset_attack_membership_knn_probabilities import \ + DatasetAttackConfigMembershipKnnProbabilities, DatasetAttackMembershipKnnProbabilities from apt.risk.data_assessment.dataset_attack_result import DatasetAttackScore, DEFAULT_DATASET_NAME from apt.risk.data_assessment.dataset_attack_whole_dataset_knn_distance import \ DatasetAttackConfigWholeDatasetKnnDistance, DatasetAttackWholeDatasetKnnDistance -from apt.risk.data_assessment.dataset_attack_membership_knn_probabilities import \ - DatasetAttackConfigMembershipKnnProbabilities, DatasetAttackMembershipKnnProbabilities from apt.utils.datasets import ArrayDataset @@ -49,23 +49,22 @@ class DatasetAssessmentManager: """ config_gl = DatasetAttackConfigMembershipKnnProbabilities(use_batches=False, generate_plot=self.config.generate_plots) - mgr = DatasetAttackMembershipKnnProbabilities(original_data_members, - original_data_non_members, - synthetic_data, - config_gl, - dataset_name) + attack_gl = DatasetAttackMembershipKnnProbabilities(original_data_members, + original_data_non_members, + synthetic_data, + config_gl, + dataset_name) - score_g = mgr.assess_privacy() - self.attack_scores_per_record_knn_probabilities.append(score_g) + score_gl = attack_gl.assess_privacy() + self.attack_scores_per_record_knn_probabilities.append(score_gl) config_h = DatasetAttackConfigWholeDatasetKnnDistance(use_batches=False) - mgr_h = DatasetAttackWholeDatasetKnnDistance(original_data_members, original_data_non_members, synthetic_data, - config_h, - dataset_name) + attack_h = DatasetAttackWholeDatasetKnnDistance(original_data_members, original_data_non_members, + synthetic_data, config_h, dataset_name) - score_h = mgr_h.assess_privacy() + score_h = attack_h.assess_privacy() self.attack_scores_whole_dataset_knn_distance.append(score_h) - return [score_g, score_h] + return [score_gl, score_h] def dump_all_scores_to_files(self): if self.config.persist_reports: diff --git a/apt/risk/data_assessment/dataset_attack_result.py b/apt/risk/data_assessment/dataset_attack_result.py index 55a4e2f..0ed0bd4 100644 --- a/apt/risk/data_assessment/dataset_attack_result.py +++ b/apt/risk/data_assessment/dataset_attack_result.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Optional import numpy as np