Notebook fix

This commit is contained in:
abigailt 2021-08-18 07:51:23 +03:00
parent 43952e2332
commit e44da7d1b5

View file

@ -147,8 +147,10 @@
"\n", "\n",
"# default target_accuracy is 0.998\n", "# default target_accuracy is 0.998\n",
"minimizer = GeneralizeToRepresentative(model)\n", "minimizer = GeneralizeToRepresentative(model)\n",
"# Can be done either on training or test data. Doing it with test data is better as the resulting accuracy on test\n", "\n",
"# data will be closer to the desired target accuracy (when working with training data it could result in a larger gap)\n", "# Fitting the minimizar can be done either on training or test data. Doing it with test data is better as the \n",
"# resulting accuracy on test data will be closer to the desired target accuracy (when working with training \n",
"# data it could result in a larger gap)\n",
"# Don't forget to leave a hold-out set for final validation!\n", "# Don't forget to leave a hold-out set for final validation!\n",
"X_generalizer_train, x_test, y_generalizer_train, y_test = train_test_split(x_test, y_test, stratify=y_test,\n", "X_generalizer_train, x_test, y_generalizer_train, y_test = train_test_split(x_test, y_test, stratify=y_test,\n",
" test_size = 0.4, random_state = 38)\n", " test_size = 0.4, random_state = 38)\n",
@ -195,7 +197,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -218,8 +220,8 @@
} }
], ],
"source": [ "source": [
"# We allow a 2.5% deviation in accuracy from the original model accuracy\n", "# We allow a 1% deviation in accuracy from the original model accuracy\n",
"minimizer2 = GeneralizeToRepresentative(model, target_accuracy=0.975)\n", "minimizer2 = GeneralizeToRepresentative(model, target_accuracy=0.99)\n",
"\n", "\n",
"minimizer2.fit(X_generalizer_train, x_train_predictions)\n", "minimizer2.fit(X_generalizer_train, x_train_predictions)\n",
"transformed2 = minimizer2.transform(x_test)\n", "transformed2 = minimizer2.transform(x_test)\n",