Rename and move knn_learner in attack strategy utils for readability according to review

Signed-off-by: Maya Anderson <mayaa@il.ibm.com>
This commit is contained in:
Maya Anderson 2023-03-08 10:28:08 +02:00
parent 185d9b9664
commit 69a9a8fa2b
3 changed files with 19 additions and 19 deletions

View file

@ -30,10 +30,10 @@ class KNNAttackStrategyUtils(AttackStrategyUtils):
if batch_size < 1:
raise ValueError(f"When using batching batch_size should be > 0, and not {batch_size}")
def fit(self, dataset: ArrayDataset, knn_learner: NearestNeighbors):
def fit(self, knn_learner: NearestNeighbors, dataset: ArrayDataset):
knn_learner.fit(dataset.get_samples())
def find_knn(self, query_samples: ArrayDataset, knn_learner: NearestNeighbors, distance_processor=None):
def find_knn(self, knn_learner: NearestNeighbors, query_samples: ArrayDataset, distance_processor=None):
"""
Nearest neighbor search function.
:param query_samples: query samples, to which nearest neighbors are to be found

View file

@ -76,13 +76,13 @@ class DatasetAttackWholeDatasetKnnDistance(DatasetAttackWhole):
super().__init__(original_data_members, original_data_non_members, synthetic_data, dataset_name,
attack_strategy_utils, config)
if config.compute_distance:
self.nn_obj_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance,
metric_params=config.distance_params)
self.nn_obj_non_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance,
metric_params=config.distance_params)
self.knn_learner_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance,
metric_params=config.distance_params)
self.knn_learner_non_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance,
metric_params=config.distance_params)
else:
self.nn_obj_members = NearestNeighbors(n_neighbors=K)
self.nn_obj_non_members = NearestNeighbors(n_neighbors=K)
self.knn_learner_members = NearestNeighbors(n_neighbors=K)
self.knn_learner_non_members = NearestNeighbors(n_neighbors=K)
def assess_privacy(self) -> DatasetAttackScoreWholeDatasetKnnDistance:
"""
@ -98,7 +98,7 @@ class DatasetAttackWholeDatasetKnnDistance(DatasetAttackWhole):
n_non_members = len(self.original_data_non_members.get_samples())
# percent of synth. records closer to members,
# and half those whose distance is similar to members and non-members
# and distance ties are divided equally between members and non-members
share = np.mean(member_distances < non_member_distances) + (n_members / (n_members + n_non_members)) * np.mean(
member_distances == non_member_distances)
score = DatasetAttackScoreWholeDatasetKnnDistance(self.dataset_name, share=share)
@ -115,11 +115,11 @@ class DatasetAttackWholeDatasetKnnDistance(DatasetAttackWhole):
neg_distances: distances of each synthetic data member from its nearest validation sample
"""
# nearest neighbor search
self.attack_strategy_utils.fit(self.original_data_members, self.nn_obj_members)
self.attack_strategy_utils.fit(self.original_data_non_members, self.nn_obj_non_members)
self.attack_strategy_utils.fit(self.knn_learner_members, self.original_data_members)
self.attack_strategy_utils.fit(self.knn_learner_non_members, self.original_data_non_members)
# distances of the synthetic data from the positive and negative samples (members and non-members)
pos_distances = self.attack_strategy_utils.find_knn(self.synthetic_data, self.nn_obj_members)
neg_distances = self.attack_strategy_utils.find_knn(self.synthetic_data, self.nn_obj_non_members)
pos_distances = self.attack_strategy_utils.find_knn(self.knn_learner_members, self.synthetic_data)
neg_distances = self.attack_strategy_utils.find_knn(self.knn_learner_non_members, self.synthetic_data)
return pos_distances, neg_distances

View file

@ -76,10 +76,10 @@ class DatasetAttackPerRecordKnnProbabilities(DatasetAttackPerRecord):
super().__init__(original_data_members, original_data_non_members, synthetic_data, dataset_name,
attack_strategy_utils, config)
if config.compute_distance:
self.nn_obj = NearestNeighbors(n_neighbors=config.k, algorithm='auto', metric=config.compute_distance,
metric_params=config.distance_params)
self.knn_learner = NearestNeighbors(n_neighbors=config.k, algorithm='auto', metric=config.compute_distance,
metric_params=config.distance_params)
else:
self.nn_obj = NearestNeighbors(n_neighbors=config.k, algorithm='auto')
self.knn_learner = NearestNeighbors(n_neighbors=config.k, algorithm='auto')
def assess_privacy(self) -> DatasetAttackResultPerRecord:
"""
@ -98,14 +98,14 @@ class DatasetAttackPerRecordKnnProbabilities(DatasetAttackPerRecord):
synthetic data generator based on the NN distances from the query samples to the synthetic data samples
"""
# nearest neighbor search
self.attack_strategy_utils.fit(self.synthetic_data, self.nn_obj)
self.attack_strategy_utils.fit(self.knn_learner, self.synthetic_data)
# positive query
pos_proba = self.attack_strategy_utils.find_knn(self.original_data_members, self.nn_obj,
pos_proba = self.attack_strategy_utils.find_knn(self.knn_learner, self.original_data_members,
self.probability_per_sample)
# negative query
neg_proba = self.attack_strategy_utils.find_knn(self.original_data_non_members, self.nn_obj,
neg_proba = self.attack_strategy_utils.find_knn(self.knn_learner, self.original_data_non_members,
self.probability_per_sample)
result = DatasetAttackResultPerRecord(self.dataset_name, positive_probabilities=pos_proba,