ai-privacy-toolkit/apt/risk/data_assessment/dataset_attack.py
andersonm-ibm a40484e0c9
Add column distribution comparison, and a third method for dataset assessment by membership classification (#84)
* Add column distribution comparison, and a third method for dataset assessment by membership classification

* Address review comments, add additional distribution comparison tests and make them externally configurable too, in addition to the alpha becoming configurable.

Signed-off-by: Maya Anderson <mayaa@il.ibm.com>
2023-09-21 16:43:19 +03:00

121 lines
5.3 KiB
Python

"""
This module defines the interface for privacy risk assessment of synthetic datasets.
"""
import abc
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
from sklearn.metrics import RocCurveDisplay
from apt.risk.data_assessment.attack_strategy_utils import AttackStrategyUtils
from apt.risk.data_assessment.dataset_attack_result import DatasetAttackScore, DatasetAttackResultMembership
from apt.utils.datasets import ArrayDataset
class Config(abc.ABC):
    """
    Common abstract ancestor for every dataset attack configuration.

    A ``DatasetAttack`` receives an instance of this type as its ``config`` argument; attack-specific settings are
    expected to live on subclasses of it.
    """
class DatasetAttack(abc.ABC):
    """
    The interface for performing privacy attack for risk assessment of synthetic datasets to be used in AI model
    training. The original data members (training data) and non-members (the holdout data) should be available.
    For reliability, all the datasets should be preprocessed and normalized.
    """

    def __init__(self, original_data_members: ArrayDataset, original_data_non_members: ArrayDataset,
                 synthetic_data: ArrayDataset, config: Config, dataset_name: str,
                 categorical_features: Optional[list] = None,
                 attack_strategy_utils: Optional[AttackStrategyUtils] = None) -> None:
        """
        :param original_data_members: A container for the training original samples and labels,
                                      only samples are used in the assessment
        :param original_data_non_members: A container for the holdout original samples and labels,
                                          only samples are used in the assessment
        :param synthetic_data: A container for the synthetic samples and labels, only samples are used in the
                               assessment
        :param config: Configuration parameters to guide the assessment process
        :param dataset_name: A name to identify the dataset under attack, optional
        :param categorical_features: The list of categorical features (column names for pandas and column indexes
                                     for numpy), optional. When omitted (or None) a new empty list is used.
        :param attack_strategy_utils: Utils for use with the attack strategy, optional
        """
        self.original_data_members = original_data_members
        self.original_data_non_members = original_data_non_members
        self.synthetic_data = synthetic_data
        self.config = config
        self.dataset_name = dataset_name
        # A fresh list per instance: a shared mutable default (`= []`) would be the same object for every
        # instance, so any in-place change made by one attack would silently leak into all the others.
        self.categorical_features = [] if categorical_features is None else categorical_features
        self.attack_strategy_utils = attack_strategy_utils

    @abc.abstractmethod
    def assess_privacy(self) -> DatasetAttackScore:
        """
        Assess the privacy of the dataset
        :return:
            score: DatasetAttackScore the privacy attack risk score
        """
        pass

    @property
    @abc.abstractmethod
    def short_name(self):
        """
        A short name identifying the concrete attack; must be provided by every subclass.
        """
        pass
class DatasetAttackMembership(DatasetAttack):
    """
    An abstract base class for performing privacy risk assessment for synthetic datasets on a per-record level.
    """

    @abc.abstractmethod
    def calculate_privacy_score(self, dataset_attack_result: DatasetAttackResultMembership,
                                generate_plot: bool = False) -> DatasetAttackScore:
        """
        Calculate dataset privacy score based on the result of the privacy attack
        :param dataset_attack_result: the per-record membership attack result to be scored
        :param generate_plot: whether the implementation should also generate a plot, optional
        :return:
            score: DatasetAttackScore
        """
        pass

    @staticmethod
    def _labels_and_scores(member_probabilities: np.ndarray, non_member_probabilities: np.ndarray):
        """
        Build ground-truth membership labels and the matching scores for ROC-style metrics.
        Non-member scores come first and are labelled 0; member scores follow and are labelled 1.
        :param member_probabilities: probability estimates of the member samples, the training data
        :param non_member_probabilities: probability estimates of the non-member samples, the hold-out data
        :return:
            labels: array of 0/1 ground-truth labels, one per concatenated score
            scores: the non-member scores followed by the member scores
        """
        labels = np.concatenate((np.zeros(len(non_member_probabilities)), np.ones(len(member_probabilities))))
        scores = np.concatenate((non_member_probabilities, member_probabilities))
        return labels, scores

    @staticmethod
    def plot_roc_curve(dataset_name: str, member_probabilities: np.ndarray, non_member_probabilities: np.ndarray,
                       filename_prefix: str = ""):
        """
        Plot ROC curve and save it as ``{filename_prefix}{dataset_name}_roc_curve.png``.
        The plot is drawn on its own figure, which is closed after saving so repeated calls do not accumulate
        open matplotlib figures.
        :param dataset_name: dataset name, will become part of the plot filename
        :param member_probabilities: probability estimates of the member samples, the training data
        :param non_member_probabilities: probability estimates of the non-member samples, the hold-out data
        :param filename_prefix: name prefix for the ROC curve plot
        """
        labels, results = DatasetAttackMembership._labels_and_scores(member_probabilities,
                                                                     non_member_probabilities)
        fig, ax = plt.subplots()
        try:
            RocCurveDisplay.from_predictions(labels, results, ax=ax)
            ax.plot([0, 1], [0, 1], color="navy", linewidth=2, linestyle="--", label='No skills')
            ax.set_title('ROC curve')
            # The legend drawn by RocCurveDisplay predates the chance line; rebuild it so 'No skills' is shown.
            ax.legend(loc="lower right")
            fig.savefig(f'{filename_prefix}{dataset_name}_roc_curve.png')
        finally:
            # Release the figure even if plotting/saving fails, otherwise pyplot keeps it alive indefinitely.
            plt.close(fig)

    @staticmethod
    def calculate_metrics(member_probabilities: np.ndarray, non_member_probabilities: np.ndarray):
        """
        Calculate attack performance metrics
        :param member_probabilities: probability estimates of the member samples, the training data
        :param non_member_probabilities: probability estimates of the non-member samples, the hold-out data
        :return:
            fpr: False Positive rate
            tpr: True Positive rate
            threshold: threshold
            auc: area under the Receiver Operating Characteristic Curve
            ap: average precision score
        """
        labels, results = DatasetAttackMembership._labels_and_scores(member_probabilities,
                                                                     non_member_probabilities)
        fpr, tpr, threshold = metrics.roc_curve(labels, results, pos_label=1)
        auc = metrics.roc_auc_score(labels, results)
        ap = metrics.average_precision_score(labels, results)
        return fpr, tpr, threshold, auc, ap