mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-08 15:05:13 +02:00
Address review comments:
extract common code, add comments, change ellipsis to pass Signed-off-by: Maya Anderson <mayaa@il.ibm.com>
This commit is contained in:
parent
4a024d8d1e
commit
e7e725ea80
6 changed files with 95 additions and 86 deletions
|
|
@ -11,7 +11,7 @@ class AttackStrategyUtils(abc.ABC):
|
|||
"""
|
||||
Abstract base class for common utilities of various privacy attack strategies.
|
||||
"""
|
||||
...
|
||||
pass
|
||||
|
||||
|
||||
class KNNAttackStrategyUtils(AttackStrategyUtils):
|
||||
|
|
@ -19,7 +19,7 @@ class KNNAttackStrategyUtils(AttackStrategyUtils):
|
|||
Common utilities for attack strategy based on KNN distances.
|
||||
"""
|
||||
|
||||
def __init__(self, k: int, use_batches: bool = False, batch_size: int = 0) -> None:
|
||||
def __init__(self, k: int, use_batches: bool = False, batch_size: int = 10) -> None:
|
||||
"""
|
||||
:param k: How many nearest neighbors to search
|
||||
:param use_batches: Use batches with a progress meter or not when finding KNNs for query set
|
||||
|
|
@ -37,9 +37,9 @@ class KNNAttackStrategyUtils(AttackStrategyUtils):
|
|||
|
||||
def find_knn(self, query_samples: ArrayDataset, knn_learner: NearestNeighbors, distance_processor=None):
|
||||
"""
|
||||
Main nearest neighbor search function on synthetic data.
|
||||
:param query_samples: query samples
|
||||
:param knn_learner: unsupervised learner for implementing neighbor searches
|
||||
Nearest neighbor search function.
|
||||
:param query_samples: query samples, to which nearest neighbors are to be found
|
||||
:param knn_learner: unsupervised learner for implementing neighbor searches, after it was fitted
|
||||
:param distance_processor: function for processing the distance into another more relevant metric per sample.
|
||||
Its input is an array representing distances (the distances returned by NearestNeighbors.kneighbors() ),
|
||||
and the output should be another array with distance-based values that enable to compute the final score
|
||||
|
|
@ -55,7 +55,7 @@ class KNNAttackStrategyUtils(AttackStrategyUtils):
|
|||
else:
|
||||
return distances
|
||||
|
||||
probabilities = []
|
||||
distances = []
|
||||
for i in tqdm(range(len(samples) // self.batch_size)):
|
||||
x_batch = samples[i * self.batch_size:(i + 1) * self.batch_size]
|
||||
x_batch = np.reshape(x_batch, [self.batch_size, -1])
|
||||
|
|
@ -65,8 +65,8 @@ class KNNAttackStrategyUtils(AttackStrategyUtils):
|
|||
|
||||
# The probability of each sample to be generated
|
||||
if distance_processor:
|
||||
probability_per_sample_batch = distance_processor(dist_batch)
|
||||
probabilities.append(probability_per_sample_batch)
|
||||
distance_based_metric_per_sample_batch = distance_processor(dist_batch)
|
||||
distances.append(distance_based_metric_per_sample_batch)
|
||||
else:
|
||||
probabilities.append(dist_batch)
|
||||
return np.concatenate(probabilities)
|
||||
distances.append(dist_batch)
|
||||
return np.concatenate(distances)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from apt.utils.datasets import ArrayDataset
|
|||
|
||||
@dataclass
|
||||
class DatasetAssessmentManagerConfig:
|
||||
persist_reports: bool = True
|
||||
persist_reports: bool = False
|
||||
generate_plots: bool = False
|
||||
|
||||
|
||||
|
|
@ -25,15 +25,14 @@ class DatasetAssessmentManager:
|
|||
|
||||
def __init__(self, config: Optional[DatasetAssessmentManagerConfig] = DatasetAssessmentManagerConfig) -> None:
|
||||
"""
|
||||
:param config: Configuration parameters to guide the assessment process such as which attack
|
||||
frameworks to use, optional
|
||||
:param config: Configuration parameters to guide the dataset assessment process
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
def assess(self, original_data_members: ArrayDataset, original_data_non_members: ArrayDataset,
|
||||
synthetic_data: ArrayDataset, dataset_name: str) -> (
|
||||
synthetic_data: ArrayDataset, dataset_name: str = "dataset") -> (
|
||||
DatasetAttackScoreGanLeaks, DatasetAttackScoreHoldout):
|
||||
config_gl = DatasetAttackGanLeaksConfig(use_batches=False)
|
||||
config_gl = DatasetAttackGanLeaksConfig(use_batches=False, k=5)
|
||||
mgr = DatasetAttackGanLeaks(original_data_members,
|
||||
original_data_non_members,
|
||||
synthetic_data,
|
||||
|
|
@ -44,7 +43,7 @@ class DatasetAssessmentManager:
|
|||
score_g = mgr.calculate_privacy_score(result, generate_plot=self.config.generate_plots)
|
||||
self.gan_leaks_attack_scores.append(score_g)
|
||||
|
||||
config_h = DatasetAttackHoldoutConfig(use_batches=False)
|
||||
config_h = DatasetAttackHoldoutConfig(use_batches=False, k=5)
|
||||
mgr_h = DatasetAttackHoldout(original_data_members, original_data_non_members, synthetic_data,
|
||||
dataset_name,
|
||||
config_h)
|
||||
|
|
@ -54,7 +53,7 @@ class DatasetAssessmentManager:
|
|||
return score_g, score_h
|
||||
|
||||
def dump_all_scores_to_files(self):
|
||||
if self.config.generate_plots:
|
||||
if self.config.persist_reports:
|
||||
results_log_file = "_results.log.csv"
|
||||
self.dump_scores_to_file(self.gan_leaks_attack_scores, "gan_leaks" + results_log_file, True)
|
||||
self.dump_scores_to_file(self.holdout_attack_scores, "holdout" + results_log_file, True)
|
||||
|
|
|
|||
|
|
@ -15,25 +15,29 @@ from apt.risk.data_assessment.dataset_attack_result import DatasetAttackScore, D
|
|||
from apt.utils.datasets import ArrayDataset
|
||||
|
||||
|
||||
class Config:
|
||||
class Config(abc.ABC):
|
||||
"""
|
||||
The base class for dataset attack configurations
|
||||
"""
|
||||
...
|
||||
pass
|
||||
|
||||
|
||||
class DatasetAttack(abc.ABC):
|
||||
"""
|
||||
The interface for performing privacy risk assessment for synthetic datasets.
|
||||
The interface for performing privacy attack for risk assessment for synthetic datasets to be used in AI models.
|
||||
The original data members (training data) and non-members (the holdout data) should be available.
|
||||
For reliability, all the datasets should be preprocessed and normalized.
|
||||
"""
|
||||
|
||||
def __init__(self, original_data_members: ArrayDataset, original_data_non_members: ArrayDataset,
|
||||
synthetic_data: ArrayDataset, dataset_name: str, attack_strategy_utils: AttackStrategyUtils,
|
||||
config: Optional[Config] = Config()) -> None:
|
||||
"""
|
||||
:param original_data_members: A container for the training original samples and labels
|
||||
:param original_data_non_members: A container for the holdout original samples and labels
|
||||
:param synthetic_data: A container for the synthetic samples and labels
|
||||
:param original_data_members: A container for the training original samples and labels,
|
||||
only samples are used in the assessment
|
||||
:param original_data_non_members: A container for the holdout original samples and labels,
|
||||
only samples are used in the assessment
|
||||
:param synthetic_data: A container for the synthetic samples and labels, only samples are used in the assessment
|
||||
:param dataset_name: A name to identify the dataset under attack
|
||||
:param attack_strategy_utils: Utils for use with the attack strategy
|
||||
:param config: Configuration parameters to guide the assessment process such as which attack
|
||||
|
|
@ -52,10 +56,10 @@ class DatasetAttack(abc.ABC):
|
|||
"""
|
||||
Assess the privacy of the dataset
|
||||
:return:
|
||||
result: Union[DatasetAttackScore, DatasetAssessmentResult] can be either the final privacy attack score,
|
||||
result: Union[DatasetAttackScore, DatasetAttackResult] can be either the final privacy attack score,
|
||||
or an intermediate attack result, which can be translated into a privacy score if needed
|
||||
"""
|
||||
...
|
||||
pass
|
||||
|
||||
|
||||
class DatasetAttackPerRecord(DatasetAttack):
|
||||
|
|
@ -68,9 +72,9 @@ class DatasetAttackPerRecord(DatasetAttack):
|
|||
"""
|
||||
Assess the privacy of the dataset
|
||||
:return:
|
||||
result: DatasetAssessmentResult
|
||||
result: DatasetAttackResultPerRecord
|
||||
"""
|
||||
...
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def calculate_privacy_score(self, dataset_attack_result: DatasetAttackResultPerRecord,
|
||||
|
|
@ -80,13 +84,13 @@ class DatasetAttackPerRecord(DatasetAttack):
|
|||
:return:
|
||||
result: DatasetAttackScore
|
||||
"""
|
||||
...
|
||||
pass
|
||||
|
||||
def plot_roc_curve(self, pos_probabilities, neg_probabilities, name_prefix=""):
|
||||
"""
|
||||
Plot ROC curve
|
||||
:param pos_probabilities: loss of the positive samples, the training data
|
||||
:param neg_probabilities: loss of the negative samples, the hold-out data
|
||||
:param pos_probabilities: probability estimates of the positive samples, the training data
|
||||
:param neg_probabilities: probability estimates of the negative samples, the hold-out data
|
||||
:param name_prefix: name prefix for the ROC curve plot
|
||||
"""
|
||||
labels = np.concatenate((np.zeros((len(neg_probabilities),)), np.ones((len(pos_probabilities),))))
|
||||
|
|
@ -98,9 +102,9 @@ class DatasetAttackPerRecord(DatasetAttack):
|
|||
plt.savefig(f'{name_prefix}{self.dataset_name}_roc_curve.png')
|
||||
|
||||
@staticmethod
|
||||
def calculate_roc_score(pos_probabilities, neg_probabilities):
|
||||
def calculate_metrics(pos_probabilities, neg_probabilities):
|
||||
"""
|
||||
Plot ROC curve
|
||||
Calculate attack performance metrics
|
||||
:param pos_probabilities: probability estimates of the positive samples, the training data
|
||||
:param neg_probabilities: probability estimates of the negative samples, the hold-out data
|
||||
:return:
|
||||
|
|
@ -110,7 +114,7 @@ class DatasetAttackPerRecord(DatasetAttack):
|
|||
auc: area under the Receiver Operating Characteristic Curve
|
||||
ap: average precision score
|
||||
"""
|
||||
labels = np.concatenate((np.zeros((len(neg_probabilities),)), np.ones((len(pos_probabilities),))))
|
||||
labels = np.concatenate((np.zeros((len(neg_probabilities),)), np.ones((len(pos_probabilities)))))
|
||||
results = np.concatenate((neg_probabilities, pos_probabilities))
|
||||
fpr, tpr, threshold = metrics.roc_curve(labels, results, pos_label=1)
|
||||
auc = metrics.roc_auc_score(labels, results)
|
||||
|
|
@ -128,6 +132,6 @@ class DatasetAttackWhole(DatasetAttack):
|
|||
"""
|
||||
Assess the privacy of the dataset
|
||||
:return:
|
||||
result: DatasetAssessmentResult
|
||||
result: DatasetAttackScore
|
||||
"""
|
||||
...
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ published in Proceedings of the 2020 ACM SIGSAC Conference on Computer and Commu
|
|||
https://doi.org/10.1145/3372297.3417238 and its implementation in https://github.com/DingfanChen/GAN-Leaks.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Optional, Callable
|
||||
|
||||
import numpy as np
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
|
|
@ -25,23 +25,25 @@ class DatasetAttackGanLeaksConfig(Config):
|
|||
use_batches: Divide query samples into batches or not.
|
||||
batch_size: Query sample batch size.
|
||||
compute_distance: A callable function, which takes two arrays representing 1D vectors as inputs and must return
|
||||
one value indicating the distance between those vectors. See sklearn.neighbors.NearestNeighbors documentation.
|
||||
batch_size: Additional keyword arguments for the distance computation function.
|
||||
one value indicating the distance between those vectors.
|
||||
See 'metric' parameter in sklearn.neighbors.NearestNeighbors documentation.
|
||||
distance_params: Additional keyword arguments for the distance computation function, see 'metric_params' in
|
||||
sklearn.neighbors.NearestNeighbors documentation.
|
||||
"""
|
||||
k: int = 1
|
||||
use_batches: bool = False
|
||||
batch_size: int = 10
|
||||
compute_distance: callable = None
|
||||
compute_distance: Callable = None
|
||||
distance_params: dict = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetAttackScoreGanLeaks(DatasetAttackScore):
|
||||
"""Configuration for DatasetAttackGanLeaks.
|
||||
"""DatasetAttackGanLeaks privacy score.
|
||||
Attributes
|
||||
----------
|
||||
roc_auc_score : the share of synthetic records closer to the training than the holdout dataset
|
||||
average_precision_score:
|
||||
roc_auc_score : the area under the receiver operating characteristic curve (AUC ROC) to evaluate the attack performance.
|
||||
average_precision_score: the proportion of Predicted Positive cases that are correctly Real Positives (members)
|
||||
assessment_type : assessment type is 'GANLeaks', to be used in reports
|
||||
"""
|
||||
roc_auc_score: float
|
||||
|
|
@ -51,9 +53,9 @@ class DatasetAttackScoreGanLeaks(DatasetAttackScore):
|
|||
|
||||
class DatasetAttackGanLeaks(DatasetAttackPerRecord):
|
||||
"""
|
||||
Privacy risk assessment for synthetic datasets based Black-Box MIA attack using distances of
|
||||
Privacy risk assessment for synthetic datasets based on Black-Box MIA attack using distances of
|
||||
members (training set) and non-members (holdout set) from their nearest neighbors in the synthetic dataset.
|
||||
The area under the receiver operating characteristic curve (AUCROC) gives the privacy risk measure.
|
||||
The area under the receiver operating characteristic curve (AUC ROC) gives the privacy risk measure.
|
||||
"""
|
||||
|
||||
def __init__(self, original_data_members: ArrayDataset, original_data_non_members: ArrayDataset,
|
||||
|
|
@ -64,8 +66,7 @@ class DatasetAttackGanLeaks(DatasetAttackPerRecord):
|
|||
:param original_data_non_members: A container for the holdout original samples and labels
|
||||
:param synthetic_data: A container for the synthetic samples and labels
|
||||
:param dataset_name: A name to identify this dataset
|
||||
:param config: Configuration parameters to guide the assessment process such as which attack
|
||||
frameworks to use, optional
|
||||
:param config: Configuration parameters to guide the attack, optional
|
||||
"""
|
||||
attack_strategy_utils = KNNAttackStrategyUtils(config.k, config.use_batches, config.batch_size)
|
||||
super().__init__(original_data_members, original_data_non_members, synthetic_data, dataset_name,
|
||||
|
|
@ -78,9 +79,19 @@ class DatasetAttackGanLeaks(DatasetAttackPerRecord):
|
|||
|
||||
def assess_privacy(self) -> DatasetAttackResultPerRecord:
|
||||
"""
|
||||
Calculate probabilities of positive and negative samples to be generated by the synthetic data generator
|
||||
:return:
|
||||
:result of the attack, based on the NN distances from the query samples to the synthetic data samples
|
||||
Membership Inference Attack which calculates probabilities of positive and negative samples to be generated by
|
||||
the synthetic data generator.
|
||||
The assumption is that since the generative model is trained to approximate the training data distribution
|
||||
then the probability of a sample to be a member of the training data should be proportional to the probability
|
||||
that the query sample can be generated by the generative model.
|
||||
The assumption is that if the probability that the query sample is generated by the generative model is large,
|
||||
it is more likely that the query sample was used to train the generative model. This probability is approximated
|
||||
by the Parzen window density estimation in 'probability_per_sample()', computed from the NN distances from the
|
||||
query samples to the synthetic data samples.
|
||||
|
||||
:return
|
||||
:result of the attack with the probabilities of positive and negative samples to be generated by the
|
||||
synthetic data generator based on the NN distances from the query samples to the synthetic data samples
|
||||
"""
|
||||
# nearest neighbor search
|
||||
self.attack_strategy_utils.fit(self.synthetic_data, self.nn_obj)
|
||||
|
|
@ -100,16 +111,17 @@ class DatasetAttackGanLeaks(DatasetAttackPerRecord):
|
|||
def calculate_privacy_score(self, dataset_attack_result: DatasetAttackResultPerRecord,
|
||||
generate_plot=False) -> DatasetAttackScore:
|
||||
"""
|
||||
Calculate probabilities of positive and negative samples to be generated by the synthetic data generator
|
||||
Evaluate privacy score from the probabilities of positive and negative samples to be generated by the synthetic
|
||||
data generator. The probabilities are computed by the 'assess_privacy()' method.
|
||||
:param dataset_attack_result attack result containing probabilities of positive and negative samples to be
|
||||
generated by the synthetic data generator
|
||||
:param generate_plot generate AUC ROC curve plot and persist it
|
||||
:return:
|
||||
:score of the attack, based on distance-based probabilities
|
||||
:return
|
||||
:score of the attack, based on distance-based probabilities - mainly the ROC AUC score
|
||||
"""
|
||||
pos_proba, neg_proba = \
|
||||
dataset_attack_result.positive_probabilities, dataset_attack_result.negative_probabilities
|
||||
fpr, tpr, threshold, auc, ap = self.calculate_roc_score(pos_proba, neg_proba)
|
||||
fpr, tpr, threshold, auc, ap = self.calculate_metrics(pos_proba, neg_proba)
|
||||
score = DatasetAttackScoreGanLeaks(self.dataset_name, roc_auc_score=auc, average_precision_score=ap)
|
||||
if generate_plot:
|
||||
self.plot_roc_curve(pos_proba, neg_proba)
|
||||
|
|
@ -119,9 +131,9 @@ class DatasetAttackGanLeaks(DatasetAttackPerRecord):
|
|||
def probability_per_sample(distances: np.ndarray):
|
||||
"""
|
||||
For every sample represented by its distance from the query sample to its KNN in synthetic data,
|
||||
the probability of the synthetic data to be part of the query dataset.
|
||||
computes the probability of the synthetic data to be part of the query dataset.
|
||||
:param distances: distance between every query sample in batch to its KNNs among synthetic samples
|
||||
:return:
|
||||
:return
|
||||
distances: probability estimates of the query samples being generated and so being part of the synthetic set
|
||||
"""
|
||||
return np.average(np.exp(-distances), axis=1)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ This module implements privacy risk assessment of synthetic datasets based on th
|
|||
"Holdout-Based Fidelity and Privacy Assessment of Mixed-Type Synthetic Data" by M. Platzer and T. Reutterer.
|
||||
and on a variation of its reference implementation in https://github.com/mostly-ai/paper-fidelity-accuracy.
|
||||
"""
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -15,8 +14,6 @@ from apt.risk.data_assessment.dataset_attack import DatasetAttackWhole, Config
|
|||
from apt.risk.data_assessment.dataset_attack_result import DatasetAttackScore
|
||||
from apt.utils.datasets import ArrayDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetAttackHoldoutConfig(Config):
|
||||
|
|
@ -28,7 +25,9 @@ class DatasetAttackHoldoutConfig(Config):
|
|||
batch_size: Query sample batch size.
|
||||
compute_distance: A callable function, which takes two arrays representing 1D vectors as inputs and must return
|
||||
one value indicating the distance between those vectors.
|
||||
batch_size: Additional keyword arguments for the distance computation function.
|
||||
See 'metric' parameter in sklearn.neighbors.NearestNeighbors documentation.
|
||||
distance_params: Additional keyword arguments for the distance computation function, see 'metric_params' in
|
||||
sklearn.neighbors.NearestNeighbors documentation.
|
||||
"""
|
||||
k: int = 1
|
||||
use_batches: bool = False
|
||||
|
|
@ -90,8 +89,10 @@ class DatasetAttackHoldout(DatasetAttackWhole):
|
|||
member_distances, non_member_distances = self.calculate_distances()
|
||||
n_members = len(member_distances)
|
||||
n_non_members = len(non_member_distances)
|
||||
assert (n_members == n_non_members)
|
||||
share = np.mean(member_distances < non_member_distances) + (n_members / (n_members + n_non_members)) * np.mean(
|
||||
assert (n_members == n_non_members) # distance of the synth. records to members and to non-members
|
||||
# percent of synth. records closer to members,
|
||||
# and half those whose distance is similar to members and non-members
|
||||
share = np.mean(member_distances < non_member_distances) + 0.5 * np.mean(
|
||||
member_distances == non_member_distances)
|
||||
score = DatasetAttackScoreHoldout(self.dataset_name, share=share)
|
||||
return score
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ diabetes_dataset_np = get_diabetes_dataset_np()
|
|||
nursery_dataset_pd = get_nursery_dataset_pd()
|
||||
adult_dataset_pd = get_adult_dataset_pd()
|
||||
|
||||
mgr = DatasetAssessmentManager(DatasetAssessmentManagerConfig(persist_reports=True, generate_plots=False))
|
||||
mgr = DatasetAssessmentManager(DatasetAssessmentManagerConfig(persist_reports=False, generate_plots=False))
|
||||
|
||||
|
||||
def teardown_function():
|
||||
|
|
@ -40,28 +40,26 @@ def test_risk_anonymization(name, data, dataset_type, k, mgr):
|
|||
(x_train, y_train), (x_test, y_test) = data
|
||||
|
||||
if dataset_type == 'np':
|
||||
original_data_members = ArrayDataset(x_train, y_train)
|
||||
# no need to preprocess
|
||||
preprocessed_x_train = x_train
|
||||
preprocessed_x_test = x_test
|
||||
QI = [0, 2]
|
||||
anonymizer = Anonymize(k, QI, train_only_QI=True)
|
||||
anonymized_data = ArrayDataset(anonymizer.anonymize(original_data_members))
|
||||
original_data_non_members = ArrayDataset(x_test, y_test)
|
||||
elif "adult" in name:
|
||||
encoded, encoded_test = preprocess_adult_x_data(x_train, x_test)
|
||||
preprocessed_x_train, preprocessed_x_test = preprocess_adult_x_data(x_train, x_test)
|
||||
QI = list(range(15, 27))
|
||||
anonymizer = Anonymize(k, QI)
|
||||
anonymized_data = ArrayDataset(anonymizer.anonymize(ArrayDataset(encoded, y_train)))
|
||||
original_data_members = ArrayDataset(encoded, y_train)
|
||||
original_data_non_members = ArrayDataset(encoded_test, y_test)
|
||||
elif "nursery" in name:
|
||||
encoded, encoded_test = preprocess_nursery_x_data(x_train, x_test)
|
||||
preprocessed_x_train, preprocessed_x_test = preprocess_nursery_x_data(x_train, x_test)
|
||||
QI = list(range(15, 27))
|
||||
anonymizer = Anonymize(k, QI, train_only_QI=True)
|
||||
anonymized_data = ArrayDataset(anonymizer.anonymize(ArrayDataset(encoded, y_train)))
|
||||
original_data_members = ArrayDataset(encoded, y_train)
|
||||
original_data_non_members = ArrayDataset(encoded_test, y_test)
|
||||
else:
|
||||
raise ValueError('Pandas dataset missing a preprocessing step')
|
||||
|
||||
anonymized_data = ArrayDataset(anonymizer.anonymize(ArrayDataset(preprocessed_x_train, y_train)))
|
||||
original_data_members = ArrayDataset(preprocessed_x_train, y_train)
|
||||
original_data_non_members = ArrayDataset(preprocessed_x_test, y_test)
|
||||
|
||||
score_g, score_h = mgr.assess(original_data_members, original_data_non_members, anonymized_data,
|
||||
f'anon_k{k}_{name}')
|
||||
assert (score_g.roc_auc_score > 0.5)
|
||||
|
|
@ -80,29 +78,24 @@ testdata = [('iris_np', iris_dataset_np, 'np', mgr),
|
|||
def test_risk_kde(name, data, dataset_type, mgr):
|
||||
(x_train, y_train), (x_test, y_test) = data
|
||||
|
||||
original_data_members = ArrayDataset(x_train, y_train)
|
||||
original_data_non_members = ArrayDataset(x_test, y_test)
|
||||
|
||||
if dataset_type == 'np':
|
||||
synth_data = ArrayDataset(kde(NUM_SYNTH_SAMPLES, n_components=NUM_SYNTH_COMPONENTS,
|
||||
original_data=original_data_members.get_samples()))
|
||||
encoded = x_train
|
||||
encoded_test = x_test
|
||||
num_synth_components = NUM_SYNTH_COMPONENTS
|
||||
elif "adult" in name:
|
||||
encoded, encoded_test = preprocess_adult_x_data(x_train, x_test)
|
||||
num_synth_components = 10
|
||||
synth_data = ArrayDataset(
|
||||
kde(NUM_SYNTH_SAMPLES, n_components=num_synth_components, original_data=encoded))
|
||||
original_data_members = ArrayDataset(encoded, y_train)
|
||||
original_data_non_members = ArrayDataset(encoded_test, y_test)
|
||||
elif "nursery" in name:
|
||||
encoded, encoded_test = preprocess_nursery_x_data(x_train, x_test)
|
||||
num_synth_components = 10
|
||||
synth_data = ArrayDataset(
|
||||
kde(NUM_SYNTH_SAMPLES, n_components=num_synth_components, original_data=encoded))
|
||||
original_data_members = ArrayDataset(encoded, y_train)
|
||||
original_data_non_members = ArrayDataset(encoded_test, y_test)
|
||||
else:
|
||||
raise ValueError('Pandas dataset missing a preprocessing step')
|
||||
|
||||
synth_data = ArrayDataset(
|
||||
kde(NUM_SYNTH_SAMPLES, n_components=num_synth_components, original_data=encoded))
|
||||
original_data_members = ArrayDataset(encoded, y_train)
|
||||
original_data_non_members = ArrayDataset(encoded_test, y_test)
|
||||
|
||||
score_g, score_h = mgr.assess(original_data_members, original_data_non_members, synth_data,
|
||||
'kde' + str(NUM_SYNTH_SAMPLES) + name)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue