New model wrappers (#32)

* keras wrapper + blackbox classifier wrapper (fix #7)

* fix error in NCP calculation

* Update notebooks

* Fix #25 (incorrect attack_feature indexes for social feature in notebook)

* Consistent naming of internal parameters
This commit is contained in:
abigailgold 2022-05-12 15:44:29 +03:00 committed by GitHub
parent fd6be8e778
commit fe676fa426
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 1407 additions and 656 deletions

View file

@ -27,7 +27,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 1,
"metadata": {},
"outputs": [
{
@ -42,18 +42,6 @@
" [2.2000e+01 9.0000e+00 0.0000e+00 0.0000e+00 2.0000e+01]\n",
" [5.2000e+01 9.0000e+00 1.5024e+04 0.0000e+00 4.0000e+01]]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/9b/qbtw28w53355cvpjs4qn83yc0000gn/T/ipykernel_13726/1357868359.py:22: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" y_train = y_train.astype(np.int)\n",
"/var/folders/9b/qbtw28w53355cvpjs4qn83yc0000gn/T/ipykernel_13726/1357868359.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" y_test = y_test.astype(np.int)\n"
]
}
],
"source": [
@ -96,24 +84,28 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Base model accuracy: 0.8183158282660771\n"
"Base model accuracy: 0.8190528837295007\n"
]
}
],
"source": [
"import os\n",
"import sys\n",
"sys.path.insert(0, os.path.abspath('..'))\n",
"\n",
"from apt.utils.datasets import ArrayDataset\n",
"from apt.utils.models import SklearnClassifier, ModelOutputType\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"base_est = DecisionTreeClassifier()\n",
"model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_VECTOR)\n",
"model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES)\n",
"model.fit(ArrayDataset(x_train, y_train))\n",
"\n",
"print('Base model accuracy: ', model.score(ArrayDataset(x_test, y_test)))"
@ -129,34 +121,30 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initial accuracy of model on generalized data, relative to original model predictions (base generalization derived from tree, before improvements): 0.936540\n",
"Initial accuracy of model on generalized data, relative to original model predictions (base generalization derived from tree, before improvements): 0.920665\n",
"Improving accuracy\n",
"feature to remove: 2\n",
"Removed feature: 2, new relative accuracy: 0.935261\n",
"feature to remove: 4\n",
"Removed feature: 4, new relative accuracy: 0.946776\n",
"feature to remove: 0\n",
"Removed feature: 0, new relative accuracy: 0.972876\n",
"feature to remove: 1\n",
"Removed feature: 1, new relative accuracy: 0.992835\n",
"Removed feature: 1, new relative accuracy: 0.920026\n",
"feature to remove: 0\n",
"Removed feature: 0, new relative accuracy: 0.938580\n",
"feature to remove: 4\n",
"Removed feature: 4, new relative accuracy: 0.987204\n",
"feature to remove: 2\n",
"Removed feature: 2, new relative accuracy: 0.992962\n",
"feature to remove: 3\n",
"Removed feature: 3, new relative accuracy: 1.000000\n",
"Accuracy on minimized data: 0.8231229847996315\n"
"Accuracy on minimized data: 0.8165771297006907\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"sys.path.insert(0, os.path.abspath('..'))\n",
"\n",
"from apt.minimization import GeneralizeToRepresentative\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
@ -169,7 +157,7 @@
"# Don't forget to leave a hold-out set for final validation!\n",
"X_generalizer_train, x_test, y_generalizer_train, y_test = train_test_split(x_test, y_test, stratify=y_test,\n",
" test_size = 0.4, random_state = 38)\n",
"x_train_predictions = model.predict(X_generalizer_train)\n",
"x_train_predictions = model.predict(ArrayDataset(X_generalizer_train))\n",
"if x_train_predictions.shape[1] > 1:\n",
" x_train_predictions = np.argmax(x_train_predictions, axis=1)\n",
"minimizer.fit(dataset=ArrayDataset(X_generalizer_train, x_train_predictions))\n",
@ -187,14 +175,14 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'ranges': {}, 'categories': {}, 'untouched': ['4', '1', '3', '0', '2']}\n"
"{'ranges': {}, 'categories': {}, 'untouched': ['2', '4', '3', '1', '0']}\n"
]
}
],
@ -214,25 +202,25 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initial accuracy of model on generalized data, relative to original model predictions (base generalization derived from tree, before improvements): 0.936540\n",
"Initial accuracy of model on generalized data, relative to original model predictions (base generalization derived from tree, before improvements): 0.920665\n",
"Improving accuracy\n",
"feature to remove: 2\n",
"Removed feature: 2, new relative accuracy: 0.935261\n",
"feature to remove: 4\n",
"Removed feature: 4, new relative accuracy: 0.946776\n",
"feature to remove: 0\n",
"Removed feature: 0, new relative accuracy: 0.972876\n",
"feature to remove: 1\n",
"Removed feature: 1, new relative accuracy: 0.992835\n",
"Accuracy on minimized data: 0.8192845079072624\n",
"{'ranges': {'3': [569.0, 782.0, 870.0, 870.5, 938.0, 1016.5, 1311.5, 1457.0, 1494.5, 1596.0, 1629.5, 1684.0, 1805.0, 1859.0, 1867.5, 1881.5, 1938.0, 1978.5, 2119.0, 2210.0, 2218.0, 2244.5, 2298.5, 2443.5]}, 'categories': {}, 'untouched': ['2', '1', '0', '4']}\n"
"Removed feature: 1, new relative accuracy: 0.920026\n",
"feature to remove: 0\n",
"Removed feature: 0, new relative accuracy: 0.938580\n",
"feature to remove: 4\n",
"Removed feature: 4, new relative accuracy: 0.987204\n",
"feature to remove: 2\n",
"Removed feature: 2, new relative accuracy: 0.992962\n",
"Accuracy on minimized data: 0.8100537221795856\n",
"{'ranges': {'3': [704.0, 782.0, 870.0, 951.0, 1588.0, 1647.5, 1684.0, 1805.0, 1923.0, 2168.5]}, 'categories': {}, 'untouched': ['2', '4', '1', '0']}\n"
]
}
],
@ -276,4 +264,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}