Return a more specific class in calculate_privacy_score(). Add more type hints and comments. Make method static.

Signed-off-by: Maya Anderson <mayaa@il.ibm.com>
This commit is contained in:
Maya Anderson 2023-03-19 15:22:48 +02:00
parent 4c7cad86df
commit 89bc9f0989
3 changed files with 15 additions and 10 deletions

View file

@ -66,7 +66,7 @@ class DatasetAttackMembership(DatasetAttack):
@abc.abstractmethod
def calculate_privacy_score(self, dataset_attack_result: DatasetAttackResultMembership,
generate_plot=False) -> DatasetAttackScore:
generate_plot: bool = False) -> DatasetAttackScore:
"""
Calculate dataset privacy score based on the result of the privacy attack
:return:
@ -74,12 +74,15 @@ class DatasetAttackMembership(DatasetAttack):
"""
pass
def plot_roc_curve(self, member_probabilities, non_member_probabilities, name_prefix=""):
@staticmethod
def plot_roc_curve(dataset_name: str, member_probabilities: np.ndarray, non_member_probabilities: np.ndarray,
filename_prefix: str = ""):
"""
Plot ROC curve
:param dataset_name: dataset name, will become part of the plot filename
:param member_probabilities: probability estimates of the member samples, the training data
:param non_member_probabilities: probability estimates of the non-member samples, the hold-out data
:param name_prefix: name prefix for the ROC curve plot
:param filename_prefix: name prefix for the ROC curve plot
"""
labels = np.concatenate((np.zeros((len(non_member_probabilities),)), np.ones((len(member_probabilities),))))
results = np.concatenate((non_member_probabilities, member_probabilities))
@ -87,10 +90,10 @@ class DatasetAttackMembership(DatasetAttack):
svc_disp.plot()
plt.plot([0, 1], [0, 1], color="navy", linewidth=2, linestyle="--", label='No skills')
plt.title('ROC curve')
plt.savefig(f'{name_prefix}{self.dataset_name}_roc_curve.png')
plt.savefig(f'{filename_prefix}{dataset_name}_roc_curve.png')
@staticmethod
def calculate_metrics(member_probabilities, non_member_probabilities):
def calculate_metrics(member_probabilities: np.ndarray, non_member_probabilities: np.ndarray):
"""
Calculate attack performance metrics
:param member_probabilities: probability estimates of the member samples, the training data