diff --git a/apt/risk/data_assessment/dataset_attack.py b/apt/risk/data_assessment/dataset_attack.py index 255cffc..4cac42d 100644 --- a/apt/risk/data_assessment/dataset_attack.py +++ b/apt/risk/data_assessment/dataset_attack.py @@ -66,7 +66,7 @@ class DatasetAttackMembership(DatasetAttack): @abc.abstractmethod def calculate_privacy_score(self, dataset_attack_result: DatasetAttackResultMembership, - generate_plot=False) -> DatasetAttackScore: + generate_plot: bool = False) -> DatasetAttackScore: """ Calculate dataset privacy score based on the result of the privacy attack :return: @@ -74,12 +74,15 @@ class DatasetAttackMembership(DatasetAttack): """ pass - def plot_roc_curve(self, member_probabilities, non_member_probabilities, name_prefix=""): + @staticmethod + def plot_roc_curve(dataset_name: str, member_probabilities: np.ndarray, non_member_probabilities: np.ndarray, + filename_prefix: str = ""): """ Plot ROC curve + :param dataset_name: dataset name, will become part of the plot filename :param member_probabilities: probability estimates of the member samples, the training data :param non_member_probabilities: probability estimates of the non-member samples, the hold-out data - :param name_prefix: name prefix for the ROC curve plot + :param filename_prefix: name prefix for the ROC curve plot """ labels = np.concatenate((np.zeros((len(non_member_probabilities),)), np.ones((len(member_probabilities),)))) results = np.concatenate((non_member_probabilities, member_probabilities)) @@ -87,10 +90,10 @@ class DatasetAttackMembership(DatasetAttack): svc_disp.plot() plt.plot([0, 1], [0, 1], color="navy", linewidth=2, linestyle="--", label='No skills') plt.title('ROC curve') - plt.savefig(f'{name_prefix}{self.dataset_name}_roc_curve.png') + plt.savefig(f'{filename_prefix}{dataset_name}_roc_curve.png') @staticmethod - def calculate_metrics(member_probabilities, non_member_probabilities): + def calculate_metrics(member_probabilities: np.ndarray, non_member_probabilities: np.ndarray): """ Calculate attack performance metrics :param member_probabilities: probability estimates of the member samples, the training data diff --git a/apt/risk/data_assessment/dataset_attack_membership_knn_probabilities.py b/apt/risk/data_assessment/dataset_attack_membership_knn_probabilities.py index 6bc99b0..7779b17 100644 --- a/apt/risk/data_assessment/dataset_attack_membership_knn_probabilities.py +++ b/apt/risk/data_assessment/dataset_attack_membership_knn_probabilities.py @@ -126,7 +126,7 @@ class DatasetAttackMembershipKnnProbabilities(DatasetAttackMembership): return score def calculate_privacy_score(self, dataset_attack_result: DatasetAttackResultMembership, - generate_plot=False) -> DatasetAttackScore: + generate_plot: bool = False) -> DatasetAttackScoreMembershipKnnProbabilities: """ Evaluate privacy score from the probabilities of member and non-member samples to be generated by the synthetic data generator. The probabilities are computed by the ``assess_privacy()`` method. @@ -143,7 +143,7 @@ class DatasetAttackMembershipKnnProbabilities(DatasetAttackMembership): result=dataset_attack_result, roc_auc_score=auc, average_precision_score=ap) if generate_plot: - self.plot_roc_curve(member_proba, non_member_proba) + self.plot_roc_curve(self.dataset_name, member_proba, non_member_proba) return score @staticmethod @@ -151,8 +151,10 @@ class DatasetAttackMembershipKnnProbabilities(DatasetAttackMembership): """ For every sample represented by its distance from the query sample to its KNN in synthetic data, computes the probability of the synthetic data to be part of the query dataset. - :param distances: distance between every query sample in batch to its KNNs among synthetic samples + :param distances: distance between every query sample in batch to its KNNs among synthetic samples, a numpy + array of size (n, k) with n being the number of samples, k - the number of KNNs :return: - probability estimates of the query samples being generated and so - of being part of the synthetic set + probability estimates of the query samples being generated and so - of being part of the synthetic set, a + numpy array of size (n,) """ return np.average(np.exp(-distances), axis=1) diff --git a/apt/risk/data_assessment/dataset_attack_whole_dataset_knn_distance.py b/apt/risk/data_assessment/dataset_attack_whole_dataset_knn_distance.py index a57ddf1..1a57bbd 100644 --- a/apt/risk/data_assessment/dataset_attack_whole_dataset_knn_distance.py +++ b/apt/risk/data_assessment/dataset_attack_whole_dataset_knn_distance.py @@ -44,7 +44,7 @@ class DatasetAttackScoreWholeDatasetKnnDistance(DatasetAttackScore): share: float assessment_type: str = 'WholeDatasetKnnDistance' # to be used in reports - def __init__(self, dataset_name, share) -> None: + def __init__(self, dataset_name: str, share: float) -> None: """ dataset_name: dataset name to be used in reports share : the share of synthetic records closer to the training than the holdout dataset.