Support 1-hot encoded features in anonymization + fixes related to encoding in minimization (#86)

* Support 1-hot encoded features in anonymization (#72)
* Fix anonymization adult notebook + new notebook to demonstrate anonymization on 1-hot encoded data

* Minimizer: No default encoder, if none provided data is supplied to the model as is. Fix data type of representative values. Fix and add more tests.

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailgold 2023-10-19 11:48:15 +03:00 committed by GitHub
parent 26addd192f
commit 5dce961092
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 670 additions and 255 deletions

View file

@ -1,7 +1,6 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -9,7 +8,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -23,13 +21,72 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/abigailt/Library/Python/3.9/lib/python/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 39 13 2174 0 40]\n",
" [ 50 13 0 0 13]\n",
" [ 38 9 0 0 40]\n",
" ...\n",
" [ 27 13 0 0 40]\n",
" [ 26 11 0 0 48]\n",
" [ 27 9 0 0 40]]\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"import os\n",
"import sys\n",
"sys.path.insert(0, os.path.abspath('..'))\n",
"from apt.utils.dataset_utils import get_adult_dataset_pd\n",
"\n",
"# requires a folder called 'datasets' in the current directory\n",
"(x_train, y_train), (x_test, y_test) = get_adult_dataset_pd()\n",
"x_train = x_train.to_numpy()\n",
"y_train = y_train.to_numpy().astype(int)\n",
"x_test = x_test.to_numpy()\n",
"y_test = y_test.to_numpy().astype(int)\n",
"\n",
"# Use only numeric features (age, education-num, capital-gain, capital-loss, hours-per-week)\n",
"x_train = x_train[:, [0, 2, 8, 9, 10]].astype(int)\n",
"x_test = x_test[:, [0, 2, 8, 9, 10]].astype(int)\n",
"\n",
"# get balanced dataset\n",
"x_train = x_train[:x_test.shape[0]]\n",
"y_train = y_train[:y_test.shape[0]]\n",
"\n",
"print(x_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train decision tree model"
]
},
{
"cell_type": "code",
"execution_count": 3,
@ -39,76 +96,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 39. 13. 2174. 0. 40.]\n",
" [ 50. 13. 0. 0. 13.]\n",
" [ 38. 9. 0. 0. 40.]\n",
" ...\n",
" [ 27. 13. 0. 0. 40.]\n",
" [ 26. 11. 0. 0. 48.]\n",
" [ 27. 9. 0. 0. 40.]]\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"# Use only numeric features (age, education-num, capital-gain, capital-loss, hours-per-week)\n",
"x_train = np.loadtxt(\"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\",\n",
" usecols=(0, 4, 10, 11, 12), delimiter=\", \")\n",
"\n",
"y_train = np.loadtxt(\"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\",\n",
" usecols=14, dtype=str, delimiter=\", \")\n",
"\n",
"\n",
"x_test = np.loadtxt(\"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test\",\n",
" usecols=(0, 4, 10, 11, 12), delimiter=\", \", skiprows=1)\n",
"\n",
"y_test = np.loadtxt(\"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test\",\n",
" usecols=14, dtype=str, delimiter=\", \", skiprows=1)\n",
"\n",
"# Trim trailing period \".\" from label\n",
"y_test = np.array([a[:-1] for a in y_test])\n",
"\n",
"y_train[y_train == '<=50K'] = 0\n",
"y_train[y_train == '>50K'] = 1\n",
"y_train = y_train.astype(int)\n",
"\n",
"y_test[y_test == '<=50K'] = 0\n",
"y_test[y_test == '>50K'] = 1\n",
"y_test = y_test.astype(int)\n",
"\n",
"# get balanced dataset\n",
"x_train = x_train[:x_test.shape[0]]\n",
"y_train = y_train[:y_test.shape[0]]\n",
"\n",
"print(x_train)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train decision tree model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Base model accuracy: 0.8076285240464345\n"
"Base model accuracy: 0.8087341072415699\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mayaa/Development/GitHub/aiprivacy/ai-privacy-toolkit/venv1/lib/python3.8/site-packages/sklearn/utils/deprecation.py:103: FutureWarning: The attribute `n_features_` is deprecated in 1.0 and will be removed in 1.2. Use `n_features_in_` instead.\n",
"/Users/abigailt/Library/Python/3.9/lib/python/site-packages/sklearn/utils/deprecation.py:103: FutureWarning: The attribute `n_features_` is deprecated in 1.0 and will be removed in 1.2. Use `n_features_in_` instead.\n",
" warnings.warn(msg, category=FutureWarning)\n"
]
}
@ -122,13 +117,10 @@
"\n",
"art_classifier = ScikitlearnDecisionTreeClassifier(model)\n",
"\n",
"print('Base model accuracy: ', model.score(x_test, y_test))\n",
"\n",
"x_train_predictions = np.array([np.argmax(arr) for arr in art_classifier.predict(x_train)]).reshape(-1,1)"
"print('Base model accuracy: ', model.score(x_test, y_test))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -139,7 +131,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -159,7 +151,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -168,14 +159,14 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.5460017196904557\n"
"0.5434836015231544\n"
]
}
],
@ -191,7 +182,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -199,7 +189,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -213,30 +202,29 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[38. 13. 0. 0. 40.]\n",
" [46. 13. 0. 0. 35.]\n",
" [28. 9. 0. 0. 40.]\n",
"[[38 13 0 0 40]\n",
" [46 13 0 0 35]\n",
" [28 9 0 0 40]\n",
" ...\n",
" [26. 13. 0. 0. 40.]\n",
" [27. 10. 0. 0. 50.]\n",
" [28. 9. 0. 0. 40.]]\n"
" [26 13 0 0 40]\n",
" [27 10 0 0 50]\n",
" [28 9 0 0 40]]\n"
]
}
],
"source": [
"import os\n",
"import sys\n",
"sys.path.insert(0, os.path.abspath('..'))\n",
"from apt.utils.datasets import ArrayDataset\n",
"from apt.anonymization import Anonymize\n",
"\n",
"x_train_predictions = np.array([np.argmax(arr) for arr in art_classifier.predict(x_train)])\n",
"\n",
"# QI = (age, education-num, capital-gain, hours-per-week)\n",
"QI = [0, 1, 2, 4]\n",
"anonymizer = Anonymize(100, QI)\n",
@ -246,7 +234,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@ -255,7 +243,7 @@
"6739"
]
},
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -267,7 +255,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@ -276,7 +264,7 @@
"401"
]
},
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@ -287,7 +275,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -296,21 +283,21 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Anonymized model accuracy: 0.826914808672686\n"
"Anonymized model accuracy: 0.8308457711442786\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mayaa/Development/GitHub/aiprivacy/ai-privacy-toolkit/venv1/lib/python3.8/site-packages/sklearn/utils/deprecation.py:103: FutureWarning: The attribute `n_features_` is deprecated in 1.0 and will be removed in 1.2. Use `n_features_in_` instead.\n",
"/Users/abigailt/Library/Python/3.9/lib/python/site-packages/sklearn/utils/deprecation.py:103: FutureWarning: The attribute `n_features_` is deprecated in 1.0 and will be removed in 1.2. Use `n_features_in_` instead.\n",
" warnings.warn(msg, category=FutureWarning)\n"
]
}
@ -325,7 +312,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -335,14 +321,14 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.49692912418621793\n"
"0.4944724235351923\n"
]
}
],
@ -364,7 +350,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -380,8 +365,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.5316007088009451, 0.7738607050730868)\n",
"(0.4971184877823882, 0.5297874953936863)\n"
"without anonymization: (0.5303914835164835, 0.7588748311018303)\n",
"with anonymization: (0.49255952380952384, 0.3659255619702739)\n"
]
}
],
@ -411,15 +396,14 @@
" return precision, recall\n",
"\n",
"# regular\n",
"print(calc_precision_recall(np.concatenate((inferred_train_bb, inferred_test_bb)), \n",
"print('without anonymization:', calc_precision_recall(np.concatenate((inferred_train_bb, inferred_test_bb)), \n",
" np.concatenate((np.ones(len(inferred_train_bb)), np.zeros(len(inferred_test_bb))))))\n",
"# anon\n",
"print(calc_precision_recall(np.concatenate((anon_inferred_train_bb, anon_inferred_test_bb)), \n",
"print('with anonymization:', calc_precision_recall(np.concatenate((anon_inferred_train_bb, anon_inferred_test_bb)), \n",
" np.concatenate((np.ones(len(anon_inferred_train_bb)), np.zeros(len(anon_inferred_test_bb))))))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -429,7 +413,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@ -443,7 +427,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.9.6"
}
},
"nbformat": 4,