{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# Using AI privacy dataset assessment"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"In this tutorial we will show how to perform privacy risk analysis of synthetic datasets for ML models using the dataset assessment module.\n",
"\n",
"This will be demonstrated using the Nursery dataset (original dataset can be found here: https://archive.ics.uci.edu/ml/datasets/nursery).\n",
"\n",
"The method `get_nursery_dataset_pd()` preprocesses the data such that all categorical features are one-hot encoded, and all the features are scaled."
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Load data\n",
"Load the nursery dataset with preprocessing and divided into a training and a test (holdout) dataset."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"sys.path.insert(0, os.path.abspath('..'))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from apt.utils.dataset_utils import get_nursery_dataset_pd\n",
"\n",
"(x_train, y_train), (x_test, y_test) = get_nursery_dataset_pd(raw=False)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### A simplistic synthetic data generator\n",
"We are using here a simple synthetic data generator just for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from sklearn.neighbors import KernelDensity\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.decomposition import PCA\n",
"import numpy as np\n",
"\n",
"\n",
"def kde(n_samples, n_components, original_data):\n",
" \"\"\"\n",
" Simple synthetic data generator: estimates the kernel density of data using a Gaussian kernel and then generates\n",
" samples from this distribution\n",
" \"\"\"\n",
" digit_data = original_data\n",
" pca = PCA(n_components=n_components, whiten=False)\n",
" data = pca.fit_transform(digit_data)\n",
" params = {'bandwidth': np.logspace(-1, 1, 20)}\n",
" grid = GridSearchCV(KernelDensity(), params, cv=5)\n",
" grid.fit(data)\n",
"\n",
" kde_estimator = grid.best_estimator_\n",
"\n",
" new_data = kde_estimator.sample(n_samples, random_state=0)\n",
" new_data = pca.inverse_transform(new_data)\n",
" return new_data"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Generate synthetic data based on the training data provided using the above simple synthetic data generator."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from apt.utils.datasets import ArrayDataset\n",
"\n",
"NUM_SYNTH_SAMPLES = 1000\n",
"num_synth_components = 4\n",
"synthetic_data = ArrayDataset(\n",
" kde(NUM_SYNTH_SAMPLES, n_components=num_synth_components, original_data=x_train))\n",
"original_data_members = ArrayDataset(x_train, y_train)\n",
"original_data_non_members = ArrayDataset(x_test, y_test)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run dataset assessment attacks using the DatasetAssessmentManager\n",
"Run all the dataset assessment attacks and get all their scores."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from apt.risk.data_assessment.dataset_assessment_manager import DatasetAssessmentManager\n",
"\n",
"mgr = DatasetAssessmentManager()\n",
"[score_g, score_h] = mgr.assess(original_data_members, original_data_non_members, synthetic_data)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"You can look at the detailed scores of all the attacks:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[DatasetAttackScoreMembershipKnnProbabilities(dataset_name='dataset', risk_score=0.5247189047081302, result=DatasetAttackResultMembership(member_probabilities=array([0.01112053, 0.03040544, 0.00952443, ..., 0.0425625 , 0.01733997,\n",
" 0.0203852 ]), non_member_probabilities=array([0.01553551, 0.01538259, 0.01611245, ..., 0.01016964, 0.01561895,\n",
" 0.01174237])), roc_auc_score=0.5247189047081302, average_precision_score=0.8141482366545616, assessment_type='MembershipKnnProbabilities'),\n",
" DatasetAttackScoreWholeDatasetKnnDistance(dataset_name='dataset', risk_score=0.841, result=None, share=0.841, assessment_type='WholeDatasetKnnDistance')]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[score_g, score_h]"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Or you can look at only the privacy risk scores of all the attacks:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.5247189047081302, 0.841]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[score_g.risk_score, score_h.risk_score]"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Run dataset assessment attacks directly"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### DatasetAttackMembershipKnnProbabilities\n",
"Run the privacy risk assessment for synthetic datasets based on Black-Box MIA attack using distances of\n",
"members (training set) and non-members (holdout set) from their nearest neighbors in the synthetic dataset.\n",
"The area under the receiver operating characteristic curve (AUC ROC) gives the privacy risk measure.\n",
"The ROC curve is displayed and saved in a file `nursery_kde_roc_curve.png`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1036/1036 [00:11<00:00, 87.53it/s] \n",
"100%|██████████| 259/259 [00:02<00:00, 109.32it/s]\n"
]
},
{
"data": {
"text/plain": [
"DatasetAttackScoreMembershipKnnProbabilities(dataset_name='nursery_kde', risk_score=0.5246348071734173, result=DatasetAttackResultMembership(member_probabilities=array([0.01112053, 0.03040544, 0.00952443, ..., 0.01370366, 0.03162697,\n",
" 0.02039033]), non_member_probabilities=array([0.01553551, 0.01538259, 0.01611245, ..., 0.02506744, 0.02278329,\n",
" 0.01016964])), roc_auc_score=0.5246348071734173, average_precision_score=0.8140989865974944, assessment_type='MembershipKnnProbabilities')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACKF0lEQVR4nOzddVxU2fsH8M/QoISKqCAKdiuKotiKYrdgo67duXbnfl1r1dV1DWywdQ3swC6wUEzEAAyke+b8/vDn1VlyWIb8vF8vXst97jl3nrmLzMO9554jE0IIEBEREeUSGlmdABEREVFGYnFDREREuQqLGyIiIspVWNwQERFRrsLihoiIiHIVFjdERESUq7C4ISIiolyFxQ0RERHlKixuiIiIKFdhcUNERES5CosbIkqVq6srZDKZ9KWlpQULCwv0798f79+/T7KPEAI7duxAo0aNYGJiAgMDA1StWhXz589HZGRksq916NAhtG7dGqamptDR0YG5uTmcnJxw/vx5db09IsplZFxbiohS4+rqigEDBmD+/PmwtrZGTEwMbty4AVdXV1hZWeHRo0fQ09OT2svlcvTq1Qt79+5Fw4YN0aVLFxgYGMDT0xO7d+9GpUqVcPbsWRQpUkTqI4TAwIED4erqChsbG3Tr1g1FixZFQEAADh06hLt37+Lq1auwt7fPilNARDmJICJKxdatWwUAcfv2baX4lClTBADh7u6uFF+8eLEAICZNmpToWEePHhUaGhqiVatWSvFly5YJAGLcuHFCoVAk6rd9+3Zx8+bNDHg36RcREZGlr09EacPbUkSUbg0bNgQAvHz5UopFR0dj2bJlKFeuHJYsWZKoT/v27eHi4gIPDw/cuHFD6rNkyRJUqFABv//+O2QyWaJ+ffv2RZ06dVLMR6FQYPXq1ahatSr09PRQuHBhtGrVCnfu3AEA+Pn5QSaTwdXVNVFfmUyGuXPnSttz586FTCaDj48PevXqhQIFCqBBgwZSfm/evEl0jGnTpkFHRwdfv36VYjdv3kSrVq1gbGwMAwMDNG7cGFevXk3xfRDRf8PihojSzc/PDwBQoEABKXblyhV8/foVvXr1gpaWVpL9+vXrBwA4duyY1Cc4OBi9evWCpqZmuvP55ZdfMG7cOFhaWuK3337D1KlToaenJxVR6dG9e3dERUVh8eLFGDx4MJycnCCTybB3795Ebffu3YuWLVtK5+P8+fNo1KgRwsLCMGfOHCxevBghISFo1qwZbt26le6ciChlSf/mISJKQmhoKD5//oyYmBjcvHkT8+bNg66uLtq1aye18fHxAQBUr1492eN83/fkyROl/1atWjXduV24cAGurq4YM2YMVq9eLcUnTpwI8R+GFlavXh27d+9WitWtWxfu7u6YPHmyFLt9+zZevXolXf0RQmDYsGFo2rQpTp48KV2NGjp0KCpXroyZM2fi9OnT6c6LiJLHKzdElGYODg4oXLgwLC0t0a1bN+TLlw9Hjx5F8eLFpTbh4eEAAENDw2SP831fWFiY0n9T6pOaAwcOQCaTYc6cOYn2JXWbK62GDRuWKObs7Iy7d+8q3Y5zd3eHrq4uOnbsCADw9vbG8+fP0atXL3z58gWfP3/G58+fERkZiebNm+Py5ctQKBTpzouIksfihojSbN26dThz5gz279+PNm3a4PPnz9DV1VVq871A+V7kJOXfBZCRkVGqfVLz8uVLmJubo2DBguk+RlKsra0Txbp37w4NDQ24u7sD+HaVZt++fWjdurX0Xp4/fw4AcHFxQeHChZW+Nm3ahNjYWISGhmZorkT0DW9LEVGa1alTB7a2tgCATp06oUGDBujVqxd8fX2RP39+AEDFihUBAA8ePECnTp2SPM6DBw8AAJUqVQIAVKhQAQDw8OHDZPtkhOSu4Mjl8mT76OvrJ4qZm5ujYcOG2Lt3L6ZPn44bN27A398fv/32m9Tm+1WZZcuWoUaNGkke+/s5I6KMxSs3RJ
QumpqaWLJkCT58+IC1a9dK8QYNGsDExAS7d+9OtmjYvn07AEhjdRo0aIACBQpgz549KRYaKSldujQ+fPiA4ODgZNt8H+gbEhKiFE/qyafUODs74/79+/D19YW7uzsMDAzQvn17pXyAb1elHBwckvzS1tZW+XWJKHUsbogo3Zo0aYI6depg1apViImJAQAYGBhg0qRJ8PX1xYwZMxL1OX78OFxdXeHo6Ii6detKfaZMmYInT55gypQpSQ4A3rlzZ4pPGHXt2hVCCMybNy/Rvu/HMzIygqmpKS5fvqy0/88//0z7m/7p9TQ1NbFnzx7s27cP7dq1Q758+aT9tWrVQunSpfH7778jIiIiUf9Pnz6p/JpElDa8LUVE/8nkyZPRvXt3uLq6SoNvp06dCi8vL/z222+4fv06unbtCn19fVy5cgU7d+5ExYoVsW3btkTHefz4MZYvX44LFy5IMxQHBgbi8OHDuHXrFq5du5ZsHk2bNkXfvn3xxx9/4Pnz52jVqhUUCgU8PT3RtGlTjBo1CgAwaNAgLF26FIMGDYKtrS0uX76MZ8+eqfy+zczM0LRpU6xYsQLh4eFwdnZW2q+hoYFNmzahdevWqFy5MgYMGAALCwu8f/8eFy5cgJGREf755x+VX5eI0iArZxAkopwhuRmKhRBCLpeL0qVLi9KlS4uEhASl+NatW0X9+vWFkZGR0NPTE5UrVxbz5s1Lcabf/fv3i5YtW4qCBQsKLS0tUaxYMeHs7CwuXryYap4JCQli2bJlokKFCkJHR0cULlxYtG7dWty9e1dqExUVJX755RdhbGwsDA0NhZOTk/j48aMAIObMmSO1mzNnjgAgPn36lOzr/f333wKAMDQ0FNHR0Um28fLyEl26dBGFChUSurq6omTJksLJyUmcO3cu1fdDROnDtaWIiIgoV+GYGyIiIspVWNwQERFRrsLihoiIiHIVFjdERESUq7C4ISIiolyFxQ0RERHlKnluEj+FQoEPHz7A0NDwP60UTERERJlHCIHw8HCYm5tDQyPlazN5rrj58OEDLC0tszoNIiIiSoe3b9+iePHiKbbJc8WNoaEhgG8nx8jIKIuzISIiorQICwuDpaWl9DmekjxX3Hy/FWVkZMTihoiIKIdJy5ASDigmIiKiXIXFDREREeUqLG6IiIgoV2FxQ0RERLkKixsiIiLKVVjcEBERUa7C4oaIiIhyFRY3RERElKuwuCEiIqJchcUNERER5SpZWtxcvnwZ7du3h7m5OWQyGQ4fPpxqn4sXL6JmzZrQ1dVFmTJl4OrqqvY8iYiIKOfI0uImMjIS1atXx7p169LU/vXr12jbti2aNm0Kb29vjBs3DoMGDcKpU6fUnCkRERHlFFm6cGbr1q3RunXrNLffsGEDrK2tsXz5cgBAxYoVceXKFaxcuRKOjo7qSpOIiIjS4EtIDCLlcujpaMLMUC/L8shRY26uX78OBwcHpZijoyOuX7+ebJ/Y2FiEhYUpfREREVHGSZArsGy3N4qWXgWbnnsxbMfdLM0nRxU3gYGBKFKkiFKsSJEiCAsLQ3R0dJJ9lixZAmNjY+nL0tIyM1IlIiLK9YQQmHHoIUoMOYBf+x1FQnAsgs+/xYP7QVmaV44qbtJj2rRpCA0Nlb7evn2b1SkRERHleHf8gmE97QR23fSHtqk+9K2NAQCFLY1wYmLjLM0tS8fcqKpo0aIIClKuBoOCgmBkZAR9ff0k++jq6kJXVzcz0iMiIsr1AkKjUW/JeaWYTCbDuSM9cWLPI8yc2Qi6ullbXuSo4qZevXo4ceKEUuzMmTOoV69eFmVERESU+4XFxOPXfQ/g8TgQQgiE3/sI7UL60LcyQhcbCyx3qg6ZTIZ6C5pldaoAsri4iYiIwIsXL6Tt169fw9vbGwULFkSJEiUwbdo0vH//Htu3bwcADBs2DGvXrsWvv/6KgQMH4vz589i7dy+OHz+eVW+BiIgo15r/jw+2XH0tbctjEvDlpB+in4VAJ7
827vmOhoW5YRZmmLQsHXNz584d2NjYwMbGBgAwYcIE2NjYYPbs2QCAgIAA+Pv7S+2tra1x/PhxnDlzBtWrV8fy5cu
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from apt.risk.data_assessment.dataset_attack_membership_knn_probabilities import \\\n",
" DatasetAttackConfigMembershipKnnProbabilities, DatasetAttackMembershipKnnProbabilities\n",
"\n",
"dataset_name = \"nursery_kde\"\n",
"\n",
"config_g = DatasetAttackConfigMembershipKnnProbabilities(use_batches=True,\n",
" generate_plot=True)\n",
"attack_g = DatasetAttackMembershipKnnProbabilities(original_data_members,\n",
" original_data_non_members,\n",
" synthetic_data,\n",
" config_g,\n",
" dataset_name)\n",
"\n",
"score_g = attack_g.assess_privacy()\n",
"score_g"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"### DatasetAttackWholeDatasetKnnDistance\n",
"Run the privacy risk assessment for synthetic datasets based on distances of synthetic data records from\n",
"members (training set) and non-members (holdout set). \n",
"\n",
"The privacy risk measure is the share of synthetic\n",
"records closer to the training than the holdout dataset."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/plain": [
"DatasetAttackScoreWholeDatasetKnnDistance(dataset_name='nursery_kde', risk_score=0.841, result=None, share=0.841, assessment_type='WholeDatasetKnnDistance')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from apt.risk.data_assessment.dataset_attack_whole_dataset_knn_distance import \\\n",
" DatasetAttackConfigWholeDatasetKnnDistance, DatasetAttackWholeDatasetKnnDistance\n",
" \n",
"config_h = DatasetAttackConfigWholeDatasetKnnDistance(use_batches=False)\n",
"attack_h = DatasetAttackWholeDatasetKnnDistance(original_data_members, original_data_non_members,\n",
" synthetic_data, config_h, dataset_name)\n",
"\n",
"score_h = attack_h.assess_privacy()\n",
"score_h"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv1",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "a7b572376dda99aaa0cfb20ab0ebad1d786e8d83835a737650854479888cdec3"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}