mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-08 15:05:13 +02:00
Rename and move knn_learner in attack strategy utils for readability according to review
Signed-off-by: Maya Anderson <mayaa@il.ibm.com>
This commit is contained in:
parent
185d9b9664
commit
69a9a8fa2b
3 changed files with 19 additions and 19 deletions
|
|
@ -30,10 +30,10 @@ class KNNAttackStrategyUtils(AttackStrategyUtils):
|
|||
if batch_size < 1:
|
||||
raise ValueError(f"When using batching batch_size should be > 0, and not {batch_size}")
|
||||
|
||||
def fit(self, dataset: ArrayDataset, knn_learner: NearestNeighbors):
|
||||
def fit(self, knn_learner: NearestNeighbors, dataset: ArrayDataset):
|
||||
knn_learner.fit(dataset.get_samples())
|
||||
|
||||
def find_knn(self, query_samples: ArrayDataset, knn_learner: NearestNeighbors, distance_processor=None):
|
||||
def find_knn(self, knn_learner: NearestNeighbors, query_samples: ArrayDataset, distance_processor=None):
|
||||
"""
|
||||
Nearest neighbor search function.
|
||||
:param query_samples: query samples, to which nearest neighbors are to be found
|
||||
|
|
|
|||
|
|
@ -76,13 +76,13 @@ class DatasetAttackWholeDatasetKnnDistance(DatasetAttackWhole):
|
|||
super().__init__(original_data_members, original_data_non_members, synthetic_data, dataset_name,
|
||||
attack_strategy_utils, config)
|
||||
if config.compute_distance:
|
||||
self.nn_obj_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance,
|
||||
metric_params=config.distance_params)
|
||||
self.nn_obj_non_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance,
|
||||
metric_params=config.distance_params)
|
||||
self.knn_learner_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance,
|
||||
metric_params=config.distance_params)
|
||||
self.knn_learner_non_members = NearestNeighbors(n_neighbors=K, metric=config.compute_distance,
|
||||
metric_params=config.distance_params)
|
||||
else:
|
||||
self.nn_obj_members = NearestNeighbors(n_neighbors=K)
|
||||
self.nn_obj_non_members = NearestNeighbors(n_neighbors=K)
|
||||
self.knn_learner_members = NearestNeighbors(n_neighbors=K)
|
||||
self.knn_learner_non_members = NearestNeighbors(n_neighbors=K)
|
||||
|
||||
def assess_privacy(self) -> DatasetAttackScoreWholeDatasetKnnDistance:
|
||||
"""
|
||||
|
|
@ -98,7 +98,7 @@ class DatasetAttackWholeDatasetKnnDistance(DatasetAttackWhole):
|
|||
n_non_members = len(self.original_data_non_members.get_samples())
|
||||
|
||||
# percent of synth. records closer to members,
|
||||
# and half those whose distance is similar to members and non-members
|
||||
# and distance ties are divided equally between members and non-members
|
||||
share = np.mean(member_distances < non_member_distances) + (n_members / (n_members + n_non_members)) * np.mean(
|
||||
member_distances == non_member_distances)
|
||||
score = DatasetAttackScoreWholeDatasetKnnDistance(self.dataset_name, share=share)
|
||||
|
|
@ -115,11 +115,11 @@ class DatasetAttackWholeDatasetKnnDistance(DatasetAttackWhole):
|
|||
neg_distances: distances of each synthetic data member from its nearest validation sample
|
||||
"""
|
||||
# nearest neighbor search
|
||||
self.attack_strategy_utils.fit(self.original_data_members, self.nn_obj_members)
|
||||
self.attack_strategy_utils.fit(self.original_data_non_members, self.nn_obj_non_members)
|
||||
self.attack_strategy_utils.fit(self.knn_learner_members, self.original_data_members)
|
||||
self.attack_strategy_utils.fit(self.knn_learner_non_members, self.original_data_non_members)
|
||||
|
||||
# distances of the synthetic data from the positive and negative samples (members and non-members)
|
||||
pos_distances = self.attack_strategy_utils.find_knn(self.synthetic_data, self.nn_obj_members)
|
||||
neg_distances = self.attack_strategy_utils.find_knn(self.synthetic_data, self.nn_obj_non_members)
|
||||
pos_distances = self.attack_strategy_utils.find_knn(self.knn_learner_members, self.synthetic_data)
|
||||
neg_distances = self.attack_strategy_utils.find_knn(self.knn_learner_non_members, self.synthetic_data)
|
||||
|
||||
return pos_distances, neg_distances
|
||||
|
|
|
|||
|
|
@ -76,10 +76,10 @@ class DatasetAttackPerRecordKnnProbabilities(DatasetAttackPerRecord):
|
|||
super().__init__(original_data_members, original_data_non_members, synthetic_data, dataset_name,
|
||||
attack_strategy_utils, config)
|
||||
if config.compute_distance:
|
||||
self.nn_obj = NearestNeighbors(n_neighbors=config.k, algorithm='auto', metric=config.compute_distance,
|
||||
metric_params=config.distance_params)
|
||||
self.knn_learner = NearestNeighbors(n_neighbors=config.k, algorithm='auto', metric=config.compute_distance,
|
||||
metric_params=config.distance_params)
|
||||
else:
|
||||
self.nn_obj = NearestNeighbors(n_neighbors=config.k, algorithm='auto')
|
||||
self.knn_learner = NearestNeighbors(n_neighbors=config.k, algorithm='auto')
|
||||
|
||||
def assess_privacy(self) -> DatasetAttackResultPerRecord:
|
||||
"""
|
||||
|
|
@ -98,14 +98,14 @@ class DatasetAttackPerRecordKnnProbabilities(DatasetAttackPerRecord):
|
|||
synthetic data generator based on the NN distances from the query samples to the synthetic data samples
|
||||
"""
|
||||
# nearest neighbor search
|
||||
self.attack_strategy_utils.fit(self.synthetic_data, self.nn_obj)
|
||||
self.attack_strategy_utils.fit(self.knn_learner, self.synthetic_data)
|
||||
|
||||
# positive query
|
||||
pos_proba = self.attack_strategy_utils.find_knn(self.original_data_members, self.nn_obj,
|
||||
pos_proba = self.attack_strategy_utils.find_knn(self.knn_learner, self.original_data_members,
|
||||
self.probability_per_sample)
|
||||
|
||||
# negative query
|
||||
neg_proba = self.attack_strategy_utils.find_knn(self.original_data_non_members, self.nn_obj,
|
||||
neg_proba = self.attack_strategy_utils.find_knn(self.knn_learner, self.original_data_non_members,
|
||||
self.probability_per_sample)
|
||||
|
||||
result = DatasetAttackResultPerRecord(self.dataset_name, positive_probabilities=pos_proba,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue