diff --git a/apt/risk/data_assessment/attack_strategy_utils.py b/apt/risk/data_assessment/attack_strategy_utils.py index 415dae5..8babf9a 100644 --- a/apt/risk/data_assessment/attack_strategy_utils.py +++ b/apt/risk/data_assessment/attack_strategy_utils.py @@ -30,10 +30,10 @@ class KNNAttackStrategyUtils(AttackStrategyUtils): if batch_size < 1: raise ValueError(f"When using batching batch_size should be > 0, and not {batch_size}") - def fit(self, dataset: ArrayDataset, knn_learner: NearestNeighbors): + def fit(self, knn_learner: NearestNeighbors, dataset: ArrayDataset): knn_learner.fit(dataset.get_samples()) - def find_knn(self, query_samples: ArrayDataset, knn_learner: NearestNeighbors, distance_processor=None): + def find_knn(self, knn_learner: NearestNeighbors, query_samples: ArrayDataset, distance_processor=None): """ Nearest neighbor search function. :param query_samples: query samples, to which nearest neighbors are to be found diff --git a/apt/risk/data_assessment/dataset_attack_whole_dataset_knn_distance.py b/apt/risk/data_assessment/dataset_attack_whole_dataset_knn_distance.py index 94a95cb..ddae72a 100644 --- a/apt/risk/data_assessment/dataset_attack_whole_dataset_knn_distance.py +++ b/apt/risk/data_assessment/dataset_attack_whole_dataset_knn_distance.py @@ -76,13 +76,13 @@ class DatasetAttackWholeDatasetKnnDistance(DatasetAttackWhole): super().__init__(original_data_members, original_data_non_members, synthetic_data, dataset_name, attack_strategy_utils, config) if config.compute_distance: - self.nn_obj_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance, - metric_params=config.distance_params) - self.nn_obj_non_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance, - metric_params=config.distance_params) + self.knn_learner_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance, + metric_params=config.distance_params) + self.knn_learner_non_members = NearestNeighbors(n_neighbors=K, 
metric=config.compute_distance, + metric_params=config.distance_params) else: - self.nn_obj_members = NearestNeighbors(n_neighbors=K) - self.nn_obj_non_members = NearestNeighbors(n_neighbors=K) + self.knn_learner_members = NearestNeighbors(n_neighbors=K) + self.knn_learner_non_members = NearestNeighbors(n_neighbors=K) def assess_privacy(self) -> DatasetAttackScoreWholeDatasetKnnDistance: """ @@ -98,7 +98,7 @@ class DatasetAttackWholeDatasetKnnDistance(DatasetAttackWhole): n_non_members = len(self.original_data_non_members.get_samples()) # percent of synth. records closer to members, - # and half those whose distance is similar to members and non-members + # and distance ties are counted in proportion to the members' share of the original records share = np.mean(member_distances < non_member_distances) + (n_members / (n_members + n_non_members)) * np.mean( member_distances == non_member_distances) score = DatasetAttackScoreWholeDatasetKnnDistance(self.dataset_name, share=share) @@ -115,11 +115,11 @@ class DatasetAttackWholeDatasetKnnDistance(DatasetAttackWhole): neg_distances: distances of each synthetic data member from its nearest validation sample """ # nearest neighbor search - self.attack_strategy_utils.fit(self.original_data_members, self.nn_obj_members) - self.attack_strategy_utils.fit(self.original_data_non_members, self.nn_obj_non_members) + self.attack_strategy_utils.fit(self.knn_learner_members, self.original_data_members) + self.attack_strategy_utils.fit(self.knn_learner_non_members, self.original_data_non_members) # distances of the synthetic data from the positive and negative samples (members and non-members) - pos_distances = self.attack_strategy_utils.find_knn(self.synthetic_data, self.nn_obj_members) - neg_distances = self.attack_strategy_utils.find_knn(self.synthetic_data, self.nn_obj_non_members) + pos_distances = self.attack_strategy_utils.find_knn(self.knn_learner_members, self.synthetic_data) + neg_distances = 
self.attack_strategy_utils.find_knn(self.knn_learner_non_members, self.synthetic_data) return pos_distances, neg_distances diff --git a/apt/risk/data_assessment/per_record_knn_probabilities_dataset_attack_.py b/apt/risk/data_assessment/per_record_knn_probabilities_dataset_attack_.py index 10929c7..057961d 100644 --- a/apt/risk/data_assessment/per_record_knn_probabilities_dataset_attack_.py +++ b/apt/risk/data_assessment/per_record_knn_probabilities_dataset_attack_.py @@ -76,10 +76,10 @@ class DatasetAttackPerRecordKnnProbabilities(DatasetAttackPerRecord): super().__init__(original_data_members, original_data_non_members, synthetic_data, dataset_name, attack_strategy_utils, config) if config.compute_distance: - self.nn_obj = NearestNeighbors(n_neighbors=config.k, algorithm='auto', metric=config.compute_distance, - metric_params=config.distance_params) + self.knn_learner = NearestNeighbors(n_neighbors=config.k, algorithm='auto', metric=config.compute_distance, + metric_params=config.distance_params) else: - self.nn_obj = NearestNeighbors(n_neighbors=config.k, algorithm='auto') + self.knn_learner = NearestNeighbors(n_neighbors=config.k, algorithm='auto') def assess_privacy(self) -> DatasetAttackResultPerRecord: """ @@ -98,14 +98,14 @@ class DatasetAttackPerRecordKnnProbabilities(DatasetAttackPerRecord): synthetic data generator based on the NN distances from the query samples to the synthetic data samples """ # nearest neighbor search - self.attack_strategy_utils.fit(self.synthetic_data, self.nn_obj) + self.attack_strategy_utils.fit(self.knn_learner, self.synthetic_data) # positive query - pos_proba = self.attack_strategy_utils.find_knn(self.original_data_members, self.nn_obj, + pos_proba = self.attack_strategy_utils.find_knn(self.knn_learner, self.original_data_members, self.probability_per_sample) # negative query - neg_proba = self.attack_strategy_utils.find_knn(self.original_data_non_members, self.nn_obj, + neg_proba = 
self.attack_strategy_utils.find_knn(self.knn_learner, self.original_data_non_members, self.probability_per_sample) result = DatasetAttackResultPerRecord(self.dataset_name, positive_probabilities=pos_proba,