Data and Model wrappers (#26)

* Squashed commit of wrappers:

    Wrapper minimizer

    * apply dataset wrapper on minimizer
    * apply changes on minimization notebook
    * add black_box_access and unlimited_queries params

    Dataset wrapper anonymizer

    Add features_names to ArrayDataset
    and allow providing feature names for QI and categorical features, not just indexes

    update notebooks

    categorical features and QI passed by indexes
    dataset includes feature names and is_pandas param

    add pytorch Dataset

    Remove redundant code.
    Use data wrappers in model wrapper APIs.

    add generic dataset components 

    Create initial version of wrappers for models

* Fix handling of categorical features
This commit is contained in:
abigailgold 2022-04-27 12:33:27 +03:00 committed by GitHub
parent d53818644e
commit 2b2dab6bef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 1340 additions and 752 deletions

View file

@ -29,7 +29,7 @@
},
{
"cell_type": "code",
"execution_count": 97,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@ -44,6 +44,18 @@
" [ 26. 11. 0. 0. 48.]\n",
" [ 27. 9. 0. 0. 40.]]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/9b/qbtw28w53355cvpjs4qn83yc0000gn/T/ipykernel_85828/3975777015.py:22: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" y_train = y_train.astype(np.int)\n",
"/var/folders/9b/qbtw28w53355cvpjs4qn83yc0000gn/T/ipykernel_85828/3975777015.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" y_test = y_test.astype(np.int)\n"
]
}
],
"source": [
@ -90,14 +102,14 @@
},
{
"cell_type": "code",
"execution_count": 116,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Base model accuracy: 0.8075056814691972\n"
"Base model accuracy: 0.8074442601805786\n"
]
}
],
@ -126,9 +138,18 @@
},
{
"cell_type": "code",
"execution_count": 124,
"execution_count": 8,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/olasaadi/PycharmProjects/ai-privacy-toolkit-internal/venv/lib/python3.8/site-packages/art/attacks/inference/membership_inference/black_box.py:262: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" self.attack_model.fit(np.c_[x_1, x_2], y_ready) # type: ignore\n"
]
}
],
"source": [
"from art.attacks.inference.membership_inference import MembershipInferenceBlackBox\n",
"\n",
@ -154,14 +175,14 @@
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.5440363591696352\n"
"0.545264709495148\n"
]
}
],
@ -197,7 +218,7 @@
},
{
"cell_type": "code",
"execution_count": 128,
"execution_count": 10,
"metadata": {},
"outputs": [
{
@ -215,6 +236,7 @@
}
],
"source": [
"from apt.utils.datasets import ArrayDataset\n",
"import os\n",
"import sys\n",
"sys.path.insert(0, os.path.abspath('..'))\n",
@ -223,22 +245,20 @@
"# QI = (age, education-num, capital-gain, hours-per-week)\n",
"QI = [0, 1, 2, 4]\n",
"anonymizer = Anonymize(100, QI)\n",
"anon = anonymizer.anonymize(x_train, x_train_predictions)\n",
"anon = anonymizer.anonymize(ArrayDataset(x_train, x_train_predictions))\n",
"print(anon)"
]
},
{
"cell_type": "code",
"execution_count": 104,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6739"
]
"text/plain": "6739"
},
"execution_count": 104,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@ -250,16 +270,14 @@
},
{
"cell_type": "code",
"execution_count": 129,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"658"
]
"text/plain": "658"
},
"execution_count": 129,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@ -278,14 +296,14 @@
},
{
"cell_type": "code",
"execution_count": 130,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Anonymized model accuracy: 0.8304158221239482\n"
"Anonymized model accuracy: 0.83078434985566\n"
]
}
],
@ -308,14 +326,22 @@
},
{
"cell_type": "code",
"execution_count": 131,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/olasaadi/PycharmProjects/ai-privacy-toolkit-internal/venv/lib/python3.8/site-packages/art/attacks/inference/membership_inference/black_box.py:262: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" self.attack_model.fit(np.c_[x_1, x_2], y_ready) # type: ignore\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.5034393809114359\n"
"0.5047291487532244\n"
]
}
],
@ -345,15 +371,15 @@
},
{
"cell_type": "code",
"execution_count": 132,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.5298924372550654, 0.7806166318634075)\n",
"(0.5030507735890172, 0.5671293452892765)\n"
"(0.5312420517168291, 0.7696843139663432)\n",
"(0.5048372911169745, 0.4935511607910576)\n"
]
}
],
@ -419,4 +445,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}