Return a more specific score class from calculate_privacy_score(). Add more type hints and docstring details. Make plot_roc_curve() a static method.

Signed-off-by: Maya Anderson <mayaa@il.ibm.com>
This commit is contained in:
Maya Anderson 2023-03-19 15:22:48 +02:00
parent 4c7cad86df
commit 89bc9f0989
3 changed files with 15 additions and 10 deletions

View file

@ -66,7 +66,7 @@ class DatasetAttackMembership(DatasetAttack):
@abc.abstractmethod
def calculate_privacy_score(self, dataset_attack_result: DatasetAttackResultMembership,
generate_plot=False) -> DatasetAttackScore:
generate_plot: bool = False) -> DatasetAttackScore:
"""
Calculate dataset privacy score based on the result of the privacy attack
:return:
@ -74,12 +74,15 @@ class DatasetAttackMembership(DatasetAttack):
"""
pass
def plot_roc_curve(self, member_probabilities, non_member_probabilities, name_prefix=""):
@staticmethod
def plot_roc_curve(dataset_name: str, member_probabilities: np.ndarray, non_member_probabilities: np.ndarray,
filename_prefix: str = ""):
"""
Plot ROC curve
:param dataset_name: dataset name, will become part of the plot filename
:param member_probabilities: probability estimates of the member samples, the training data
:param non_member_probabilities: probability estimates of the non-member samples, the hold-out data
:param name_prefix: name prefix for the ROC curve plot
:param filename_prefix: name prefix for the ROC curve plot
"""
labels = np.concatenate((np.zeros((len(non_member_probabilities),)), np.ones((len(member_probabilities),))))
results = np.concatenate((non_member_probabilities, member_probabilities))
@ -87,10 +90,10 @@ class DatasetAttackMembership(DatasetAttack):
svc_disp.plot()
plt.plot([0, 1], [0, 1], color="navy", linewidth=2, linestyle="--", label='No skills')
plt.title('ROC curve')
plt.savefig(f'{name_prefix}{self.dataset_name}_roc_curve.png')
plt.savefig(f'{filename_prefix}{dataset_name}_roc_curve.png')
@staticmethod
def calculate_metrics(member_probabilities, non_member_probabilities):
def calculate_metrics(member_probabilities: np.ndarray, non_member_probabilities: np.ndarray):
"""
Calculate attack performance metrics
:param member_probabilities: probability estimates of the member samples, the training data

View file

@ -126,7 +126,7 @@ class DatasetAttackMembershipKnnProbabilities(DatasetAttackMembership):
return score
def calculate_privacy_score(self, dataset_attack_result: DatasetAttackResultMembership,
generate_plot=False) -> DatasetAttackScore:
generate_plot: bool = False) -> DatasetAttackScoreMembershipKnnProbabilities:
"""
Evaluate privacy score from the probabilities of member and non-member samples to be generated by the synthetic
data generator. The probabilities are computed by the ``assess_privacy()`` method.
@ -143,7 +143,7 @@ class DatasetAttackMembershipKnnProbabilities(DatasetAttackMembership):
result=dataset_attack_result,
roc_auc_score=auc, average_precision_score=ap)
if generate_plot:
self.plot_roc_curve(member_proba, non_member_proba)
self.plot_roc_curve(self.dataset_name, member_proba, non_member_proba)
return score
@staticmethod
@ -151,8 +151,10 @@ class DatasetAttackMembershipKnnProbabilities(DatasetAttackMembership):
"""
For every sample represented by its distance from the query sample to its KNN in synthetic data,
computes the probability of the synthetic data to be part of the query dataset.
:param distances: distance between every query sample in batch to its KNNs among synthetic samples
:param distances: distance between every query sample in batch to its KNNs among synthetic samples, a numpy
array of size (n, k) with n being the number of samples, k - the number of KNNs
:return:
probability estimates of the query samples being generated and so - of being part of the synthetic set
probability estimates of the query samples being generated and so - of being part of the synthetic set, a
numpy array of size (n,)
"""
return np.average(np.exp(-distances), axis=1)

View file

@ -44,7 +44,7 @@ class DatasetAttackScoreWholeDatasetKnnDistance(DatasetAttackScore):
share: float
assessment_type: str = 'WholeDatasetKnnDistance' # to be used in reports
def __init__(self, dataset_name, share) -> None:
def __init__(self, dataset_name: str, share: float) -> None:
"""
dataset_name: dataset name to be used in reports
share : the share of synthetic records closer to the training than the holdout dataset.