mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-15 06:52:37 +02:00
New model wrappers (#32)
* keras wrapper + blackbox classifier wrapper (fix #7) * fix error in NCP calculation * Update notebooks * Fix #25 (incorrect attack_feature indexes for social feature in notebook) * Consistent naming of internal parameters
This commit is contained in:
parent
fd6be8e778
commit
fe676fa426
15 changed files with 1407 additions and 656 deletions
|
|
@ -27,7 +27,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
|
@ -42,18 +42,6 @@
|
|||
" [2.2000e+01 9.0000e+00 0.0000e+00 0.0000e+00 2.0000e+01]\n",
|
||||
" [5.2000e+01 9.0000e+00 1.5024e+04 0.0000e+00 4.0000e+01]]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/var/folders/9b/qbtw28w53355cvpjs4qn83yc0000gn/T/ipykernel_13726/1357868359.py:22: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
|
||||
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
|
||||
" y_train = y_train.astype(np.int)\n",
|
||||
"/var/folders/9b/qbtw28w53355cvpjs4qn83yc0000gn/T/ipykernel_13726/1357868359.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
|
||||
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
|
||||
" y_test = y_test.astype(np.int)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
|
|
@ -96,24 +84,28 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Base model accuracy: 0.8183158282660771\n"
|
||||
"Base model accuracy: 0.8190528837295007\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"sys.path.insert(0, os.path.abspath('..'))\n",
|
||||
"\n",
|
||||
"from apt.utils.datasets import ArrayDataset\n",
|
||||
"from apt.utils.models import SklearnClassifier, ModelOutputType\n",
|
||||
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||
"\n",
|
||||
"base_est = DecisionTreeClassifier()\n",
|
||||
"model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_VECTOR)\n",
|
||||
"model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES)\n",
|
||||
"model.fit(ArrayDataset(x_train, y_train))\n",
|
||||
"\n",
|
||||
"print('Base model accuracy: ', model.score(ArrayDataset(x_test, y_test)))"
|
||||
|
|
@ -129,34 +121,30 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Initial accuracy of model on generalized data, relative to original model predictions (base generalization derived from tree, before improvements): 0.936540\n",
|
||||
"Initial accuracy of model on generalized data, relative to original model predictions (base generalization derived from tree, before improvements): 0.920665\n",
|
||||
"Improving accuracy\n",
|
||||
"feature to remove: 2\n",
|
||||
"Removed feature: 2, new relative accuracy: 0.935261\n",
|
||||
"feature to remove: 4\n",
|
||||
"Removed feature: 4, new relative accuracy: 0.946776\n",
|
||||
"feature to remove: 0\n",
|
||||
"Removed feature: 0, new relative accuracy: 0.972876\n",
|
||||
"feature to remove: 1\n",
|
||||
"Removed feature: 1, new relative accuracy: 0.992835\n",
|
||||
"Removed feature: 1, new relative accuracy: 0.920026\n",
|
||||
"feature to remove: 0\n",
|
||||
"Removed feature: 0, new relative accuracy: 0.938580\n",
|
||||
"feature to remove: 4\n",
|
||||
"Removed feature: 4, new relative accuracy: 0.987204\n",
|
||||
"feature to remove: 2\n",
|
||||
"Removed feature: 2, new relative accuracy: 0.992962\n",
|
||||
"feature to remove: 3\n",
|
||||
"Removed feature: 3, new relative accuracy: 1.000000\n",
|
||||
"Accuracy on minimized data: 0.8231229847996315\n"
|
||||
"Accuracy on minimized data: 0.8165771297006907\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import os\n",
|
||||
"sys.path.insert(0, os.path.abspath('..'))\n",
|
||||
"\n",
|
||||
"from apt.minimization import GeneralizeToRepresentative\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"\n",
|
||||
|
|
@ -169,7 +157,7 @@
|
|||
"# Don't forget to leave a hold-out set for final validation!\n",
|
||||
"X_generalizer_train, x_test, y_generalizer_train, y_test = train_test_split(x_test, y_test, stratify=y_test,\n",
|
||||
" test_size = 0.4, random_state = 38)\n",
|
||||
"x_train_predictions = model.predict(X_generalizer_train)\n",
|
||||
"x_train_predictions = model.predict(ArrayDataset(X_generalizer_train))\n",
|
||||
"if x_train_predictions.shape[1] > 1:\n",
|
||||
" x_train_predictions = np.argmax(x_train_predictions, axis=1)\n",
|
||||
"minimizer.fit(dataset=ArrayDataset(X_generalizer_train, x_train_predictions))\n",
|
||||
|
|
@ -187,14 +175,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'ranges': {}, 'categories': {}, 'untouched': ['4', '1', '3', '0', '2']}\n"
|
||||
"{'ranges': {}, 'categories': {}, 'untouched': ['2', '4', '3', '1', '0']}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
@ -214,25 +202,25 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Initial accuracy of model on generalized data, relative to original model predictions (base generalization derived from tree, before improvements): 0.936540\n",
|
||||
"Initial accuracy of model on generalized data, relative to original model predictions (base generalization derived from tree, before improvements): 0.920665\n",
|
||||
"Improving accuracy\n",
|
||||
"feature to remove: 2\n",
|
||||
"Removed feature: 2, new relative accuracy: 0.935261\n",
|
||||
"feature to remove: 4\n",
|
||||
"Removed feature: 4, new relative accuracy: 0.946776\n",
|
||||
"feature to remove: 0\n",
|
||||
"Removed feature: 0, new relative accuracy: 0.972876\n",
|
||||
"feature to remove: 1\n",
|
||||
"Removed feature: 1, new relative accuracy: 0.992835\n",
|
||||
"Accuracy on minimized data: 0.8192845079072624\n",
|
||||
"{'ranges': {'3': [569.0, 782.0, 870.0, 870.5, 938.0, 1016.5, 1311.5, 1457.0, 1494.5, 1596.0, 1629.5, 1684.0, 1805.0, 1859.0, 1867.5, 1881.5, 1938.0, 1978.5, 2119.0, 2210.0, 2218.0, 2244.5, 2298.5, 2443.5]}, 'categories': {}, 'untouched': ['2', '1', '0', '4']}\n"
|
||||
"Removed feature: 1, new relative accuracy: 0.920026\n",
|
||||
"feature to remove: 0\n",
|
||||
"Removed feature: 0, new relative accuracy: 0.938580\n",
|
||||
"feature to remove: 4\n",
|
||||
"Removed feature: 4, new relative accuracy: 0.987204\n",
|
||||
"feature to remove: 2\n",
|
||||
"Removed feature: 2, new relative accuracy: 0.992962\n",
|
||||
"Accuracy on minimized data: 0.8100537221795856\n",
|
||||
"{'ranges': {'3': [704.0, 782.0, 870.0, 951.0, 1588.0, 1647.5, 1684.0, 1805.0, 1923.0, 2168.5]}, 'categories': {}, 'untouched': ['2', '4', '1', '0']}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
@ -276,4 +264,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue