diff --git a/apt/anonymization/anonymizer.py b/apt/anonymization/anonymizer.py
index 99c898a..8a4f95d 100644
--- a/apt/anonymization/anonymizer.py
+++ b/apt/anonymization/anonymizer.py
@@ -22,19 +22,19 @@ class Anonymize:
"""
:param k: The privacy parameter that determines the number of records that will be indistinguishable from each
other (when looking at the quasi identifiers). Should be at least 2.
- :param quasi_identifiers: The indexes of features that need to be minimized.
- :param categorical_features: The list of categorical features indexes
+ :param quasi_identifiers: The features that need to be minimized.
+ :param categorical_features: The list of categorical features.
:param is_regression: Boolean param indicates that is is a regression problem.
"""
if k < 2:
raise ValueError("k should be a positive integer with a value of 2 or higher")
if quasi_identifiers is None or len(quasi_identifiers) < 1:
raise ValueError("The list of quasi-identifiers cannot be empty")
-
self.k = k
self.quasi_identifiers = quasi_identifiers
self.categorical_features = categorical_features
self.is_regression = is_regression
+ self.features_names = None
def anonymize(self, dataset: ArrayDataset) -> DATA_PANDAS_NUMPY_TYPE:
"""
@@ -51,6 +51,15 @@ class Anonymize:
self._features = [i for i in range(dataset.get_samples().shape[0])]
else:
raise ValueError('No data provided')
+ if not set(self.quasi_identifiers).issubset(set(self.features_names)):
+ raise ValueError('Quasi identifiers should be a subset of the supplied features or indexes in range of '
+ 'the data columns')
+ if self.categorical_features and not set(self.categorical_features).issubset(set(self.features_names)):
+ raise ValueError('Categorical features should be a subset of the supplied features or indexes in range of '
+ 'the data columns')
+ self.quasi_identifiers = [i for i, v in enumerate(self.features_names) if v in self.quasi_identifiers]
+ if self.categorical_features:
+ self.categorical_features = [i for i, v in enumerate(self.features_names) if v in self.categorical_features]
transformed = self._anonymize(dataset.get_samples().copy(), dataset.get_labels())
if dataset.is_pandas:
diff --git a/apt/utils/dataset_utils.py b/apt/utils/dataset_utils.py
index f99c6cc..2405f8f 100644
--- a/apt/utils/dataset_utils.py
+++ b/apt/utils/dataset_utils.py
@@ -273,7 +273,7 @@ def get_nursery_dataset(raw: bool = True, test_set: float = 0.2, transform_socia
raise Exception("Bad label value: %s" % value)
data["label"] = data["label"].apply(modify_label)
- data["children"] = data["children"].apply(lambda x: 4 if x == "more" else x)
+ data["children"] = data["children"].apply(lambda x: "4" if x == "more" else x)
if transform_social:
diff --git a/apt/utils/datasets/datasets.py b/apt/utils/datasets/datasets.py
index ebcd7cb..29dd4e9 100644
--- a/apt/utils/datasets/datasets.py
+++ b/apt/utils/datasets/datasets.py
@@ -18,7 +18,6 @@ from torch import Tensor
logger = logging.getLogger(__name__)
-
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor]
OUTPUT_DATA_ARRAY_TYPE = np.ndarray
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
@@ -113,7 +112,6 @@ class StoredDataset(Dataset):
if unzip:
StoredDataset.extract_archive(zip_path=file_path, dest_path=dest_path, remove_archive=False)
-
@staticmethod
def extract_archive(zip_path: str, dest_path=None, remove_archive=False):
"""
@@ -164,7 +162,8 @@ class StoredDataset(Dataset):
class ArrayDataset(Dataset):
"""Dataset that is based on x and y arrays (e.g., numpy/pandas/list...)"""
- def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, **kwargs):
+ def __init__(self, x: INPUT_DATA_ARRAY_TYPE, y: Optional[INPUT_DATA_ARRAY_TYPE] = None, features_names=None,
+ **kwargs):
"""
ArrayDataset constructor.
:param x: collection of data samples
@@ -172,10 +171,12 @@ class ArrayDataset(Dataset):
:param kwargs: dataset parameters
"""
self.is_pandas = False
- self.features_names = None
+ self.features_names = features_names
self._y = array2numpy(self, y) if y is not None else None
self._x = array2numpy(self, x)
if self.is_pandas:
+ if features_names and not np.array_equal(features_names, x.columns):
+ raise ValueError("The supplied features are not the same as in the data features")
self.features_names = x.columns
if y is not None and len(self._x) != len(self._y):
@@ -213,7 +214,6 @@ class PytorchData(Dataset):
else:
self.__getitem__ = self.get_sample_item
-
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""Return data samples as numpy array"""
return array2numpy(self._x)
@@ -244,6 +244,7 @@ class DatasetFactory:
:param name: dataset name
:return:
"""
+
def inner_wrapper(wrapped_class: Dataset) -> Any:
if name in cls.registry:
logger.warning('Dataset %s already exists. Will replace it', name)
diff --git a/notebooks/attribute_inference_anonymization_nursery.ipynb b/notebooks/attribute_inference_anonymization_nursery.ipynb
index 34fa296..bfba540 100644
--- a/notebooks/attribute_inference_anonymization_nursery.ipynb
+++ b/notebooks/attribute_inference_anonymization_nursery.ipynb
@@ -29,7 +29,7 @@
},
{
"cell_type": "code",
- "execution_count": 136,
+ "execution_count": 1,
"metadata": {},
"outputs": [
{
@@ -37,7 +37,7 @@
"text/plain": " parents has_nurs form children housing finance \\\n8450 pretentious very_crit foster 1 less_conv convenient \n12147 great_pret very_crit complete 1 critical inconv \n2780 usual critical complete 4 less_conv convenient \n11924 great_pret critical foster 1 critical convenient \n59 usual proper complete 2 convenient convenient \n... ... ... ... ... ... ... \n5193 pretentious less_proper complete 1 convenient inconv \n1375 usual less_proper incomplete 2 less_conv convenient \n10318 great_pret less_proper foster 4 convenient convenient \n6396 pretentious improper completed 3 less_conv convenient \n485 usual proper incomplete 1 critical inconv \n\n social health \n8450 1 not_recom \n12147 1 recommended \n2780 1 not_recom \n11924 1 not_recom \n59 0 not_recom \n... ... ... \n5193 0 recommended \n1375 1 priority \n10318 0 priority \n6396 1 recommended \n485 1 not_recom \n\n[10366 rows x 8 columns]",
"text/html": "
\n\n
\n \n \n | \n parents | \n has_nurs | \n form | \n children | \n housing | \n finance | \n social | \n health | \n
\n \n \n \n | 8450 | \n pretentious | \n very_crit | \n foster | \n 1 | \n less_conv | \n convenient | \n 1 | \n not_recom | \n
\n \n | 12147 | \n great_pret | \n very_crit | \n complete | \n 1 | \n critical | \n inconv | \n 1 | \n recommended | \n
\n \n | 2780 | \n usual | \n critical | \n complete | \n 4 | \n less_conv | \n convenient | \n 1 | \n not_recom | \n
\n \n | 11924 | \n great_pret | \n critical | \n foster | \n 1 | \n critical | \n convenient | \n 1 | \n not_recom | \n
\n \n | 59 | \n usual | \n proper | \n complete | \n 2 | \n convenient | \n convenient | \n 0 | \n not_recom | \n
\n \n | ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n
\n \n | 5193 | \n pretentious | \n less_proper | \n complete | \n 1 | \n convenient | \n inconv | \n 0 | \n recommended | \n
\n \n | 1375 | \n usual | \n less_proper | \n incomplete | \n 2 | \n less_conv | \n convenient | \n 1 | \n priority | \n
\n \n | 10318 | \n great_pret | \n less_proper | \n foster | \n 4 | \n convenient | \n convenient | \n 0 | \n priority | \n
\n \n | 6396 | \n pretentious | \n improper | \n completed | \n 3 | \n less_conv | \n convenient | \n 1 | \n recommended | \n
\n \n | 485 | \n usual | \n proper | \n incomplete | \n 1 | \n critical | \n inconv | \n 1 | \n not_recom | \n
\n \n
\n
10366 rows × 8 columns
\n
"
},
- "execution_count": 136,
+ "execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
@@ -63,7 +63,7 @@
},
{
"cell_type": "code",
- "execution_count": 137,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -104,7 +104,7 @@
},
{
"cell_type": "code",
- "execution_count": 138,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -140,7 +140,7 @@
},
{
"cell_type": "code",
- "execution_count": 139,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -178,14 +178,14 @@
},
{
"cell_type": "code",
- "execution_count": 140,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "0.5076210688790276\n"
+ "0.5122515917422342\n"
]
}
],
@@ -225,7 +225,7 @@
},
{
"cell_type": "code",
- "execution_count": 141,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -233,7 +233,7 @@
"text/plain": " parents has_nurs form children housing finance \\\n0 pretentious very_crit foster 1 less_conv convenient \n1 great_pret very_crit complete 1 critical inconv \n2 usual critical complete 4 less_conv convenient \n3 great_pret critical foster 1 critical convenient \n4 usual proper complete 2 convenient convenient \n... ... ... ... ... ... ... \n10361 pretentious less_proper complete 1 convenient inconv \n10362 usual less_proper incomplete 2 less_conv convenient \n10363 great_pret less_proper foster 4 convenient convenient \n10364 pretentious improper completed 3 less_conv convenient \n10365 usual proper incomplete 1 critical convenient \n\n social health \n0 0 not_recom \n1 1 recommended \n2 0 not_recom \n3 0 not_recom \n4 0 not_recom \n... ... ... \n10361 0 recommended \n10362 1 priority \n10363 0 priority \n10364 1 recommended \n10365 0 not_recom \n\n[10366 rows x 8 columns]",
"text/html": "\n\n
\n \n \n | \n parents | \n has_nurs | \n form | \n children | \n housing | \n finance | \n social | \n health | \n
\n \n \n \n | 0 | \n pretentious | \n very_crit | \n foster | \n 1 | \n less_conv | \n convenient | \n 0 | \n not_recom | \n
\n \n | 1 | \n great_pret | \n very_crit | \n complete | \n 1 | \n critical | \n inconv | \n 1 | \n recommended | \n
\n \n | 2 | \n usual | \n critical | \n complete | \n 4 | \n less_conv | \n convenient | \n 0 | \n not_recom | \n
\n \n | 3 | \n great_pret | \n critical | \n foster | \n 1 | \n critical | \n convenient | \n 0 | \n not_recom | \n
\n \n | 4 | \n usual | \n proper | \n complete | \n 2 | \n convenient | \n convenient | \n 0 | \n not_recom | \n
\n \n | ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n
\n \n | 10361 | \n pretentious | \n less_proper | \n complete | \n 1 | \n convenient | \n inconv | \n 0 | \n recommended | \n
\n \n | 10362 | \n usual | \n less_proper | \n incomplete | \n 2 | \n less_conv | \n convenient | \n 1 | \n priority | \n
\n \n | 10363 | \n great_pret | \n less_proper | \n foster | \n 4 | \n convenient | \n convenient | \n 0 | \n priority | \n
\n \n | 10364 | \n pretentious | \n improper | \n completed | \n 3 | \n less_conv | \n convenient | \n 1 | \n recommended | \n
\n \n | 10365 | \n usual | \n proper | \n incomplete | \n 1 | \n critical | \n convenient | \n 0 | \n not_recom | \n
\n \n
\n
10366 rows × 8 columns
\n
"
},
- "execution_count": 141,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -254,14 +254,14 @@
},
{
"cell_type": "code",
- "execution_count": 142,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "7585"
},
- "execution_count": 142,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -273,14 +273,14 @@
},
{
"cell_type": "code",
- "execution_count": 143,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "5766"
},
- "execution_count": 143,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -299,7 +299,7 @@
},
{
"cell_type": "code",
- "execution_count": 144,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -332,7 +332,7 @@
},
{
"cell_type": "code",
- "execution_count": 145,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -368,14 +368,14 @@
},
{
"cell_type": "code",
- "execution_count": 146,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "0.5218985143739148\n"
+ "0.5245996527107852\n"
]
}
],
@@ -399,7 +399,7 @@
},
{
"cell_type": "code",
- "execution_count": 147,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -444,15 +444,15 @@
},
{
"cell_type": "code",
- "execution_count": 148,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "(0.9322033898305084, 0.01066925315227934)\n",
- "(0.9806763285024155, 0.03937924345295829)\n"
+ "(1.0, 0.019204655674102813)\n",
+ "(0.9829787234042553, 0.04481086323957323)\n"
]
}
],
@@ -483,7 +483,7 @@
},
{
"cell_type": "code",
- "execution_count": 149,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -493,14 +493,14 @@
},
{
"cell_type": "code",
- "execution_count": 150,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "4226"
},
- "execution_count": 150,
+ "execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@@ -519,7 +519,7 @@
},
{
"cell_type": "code",
- "execution_count": 151,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
@@ -552,7 +552,7 @@
},
{
"cell_type": "code",
- "execution_count": 152,
+ "execution_count": 17,
"metadata": {},
"outputs": [
{
@@ -588,14 +588,14 @@
},
{
"cell_type": "code",
- "execution_count": 153,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "0.5184256222265098\n"
+ "0.515820953115956\n"
]
}
],
@@ -612,7 +612,7 @@
},
{
"cell_type": "code",
- "execution_count": 154,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -621,8 +621,8 @@
"text": [
"(0.49415432579890883, 0.48976438779451525)\n",
"(0.49415432579890883, 0.48976438779451525)\n",
- "(0.9322033898305084, 0.01066925315227934)\n",
- "(1.0, 0.03161978661493695)\n"
+ "(1.0, 0.019204655674102813)\n",
+ "(1.0, 0.026382153249272552)\n"
]
}
],
@@ -655,34 +655,9 @@
},
{
"cell_type": "code",
- "execution_count": 155,
+ "execution_count": 20,
"metadata": {},
- "outputs": [
- {
- "ename": "TypeError",
- "evalue": "argument must be a string or number",
- "output_type": "error",
- "traceback": [
- "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
- "\u001B[0;31mTypeError\u001B[0m Traceback (most recent call last)",
- "File \u001B[0;32m~/PycharmProjects/ai-privacy-toolkit-internal/venv/lib/python3.8/site-packages/sklearn/preprocessing/_label.py:112\u001B[0m, in \u001B[0;36m_encode\u001B[0;34m(values, uniques, encode, check_unknown)\u001B[0m\n\u001B[1;32m 111\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m--> 112\u001B[0m res \u001B[38;5;241m=\u001B[39m \u001B[43m_encode_python\u001B[49m\u001B[43m(\u001B[49m\u001B[43mvalues\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43muniques\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mencode\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 113\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m:\n",
- "File \u001B[0;32m~/PycharmProjects/ai-privacy-toolkit-internal/venv/lib/python3.8/site-packages/sklearn/preprocessing/_label.py:60\u001B[0m, in \u001B[0;36m_encode_python\u001B[0;34m(values, uniques, encode)\u001B[0m\n\u001B[1;32m 59\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m uniques \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m---> 60\u001B[0m uniques \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43msorted\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mset\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mvalues\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 61\u001B[0m uniques \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39marray(uniques, dtype\u001B[38;5;241m=\u001B[39mvalues\u001B[38;5;241m.\u001B[39mdtype)\n",
- "\u001B[0;31mTypeError\u001B[0m: '<' not supported between instances of 'int' and 'str'",
- "\nDuring handling of the above exception, another exception occurred:\n",
- "\u001B[0;31mTypeError\u001B[0m Traceback (most recent call last)",
- "Input \u001B[0;32mIn [155]\u001B[0m, in \u001B[0;36m\u001B[0;34m()\u001B[0m\n\u001B[1;32m 2\u001B[0m QI2_indexes \u001B[38;5;241m=\u001B[39m [i \u001B[38;5;28;01mfor\u001B[39;00m i, v \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(features) \u001B[38;5;28;01mif\u001B[39;00m v \u001B[38;5;129;01min\u001B[39;00m QI2]\n\u001B[1;32m 3\u001B[0m anonymizer3 \u001B[38;5;241m=\u001B[39m Anonymize(\u001B[38;5;241m100\u001B[39m, QI2_indexes, categorical_features\u001B[38;5;241m=\u001B[39mcategorical_features_indexes)\n\u001B[0;32m----> 4\u001B[0m anon3 \u001B[38;5;241m=\u001B[39m \u001B[43manonymizer3\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43manonymize\u001B[49m\u001B[43m(\u001B[49m\u001B[43mArrayDataset\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx_train\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mx_train_predictions\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n",
- "File \u001B[0;32m~/PycharmProjects/ai-privacy-toolkit-internal/apt/anonymization/anonymizer.py:55\u001B[0m, in \u001B[0;36mAnonymize.anonymize\u001B[0;34m(self, dataset)\u001B[0m\n\u001B[1;32m 52\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 53\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mNo data provided\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[0;32m---> 55\u001B[0m transformed \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_anonymize\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdataset\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mget_samples\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcopy\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdataset\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mget_labels\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 56\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m dataset\u001B[38;5;241m.\u001B[39mis_pandas:\n\u001B[1;32m 57\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m pd\u001B[38;5;241m.\u001B[39mDataFrame(transformed, columns\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_features)\n",
- "File \u001B[0;32m~/PycharmProjects/ai-privacy-toolkit-internal/apt/anonymization/anonymizer.py:68\u001B[0m, in \u001B[0;36mAnonymize._anonymize\u001B[0;34m(self, x, y)\u001B[0m\n\u001B[1;32m 66\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcategorical_features:\n\u001B[1;32m 67\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mwhen supplying an array with non-numeric data, categorical_features must be defined\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[0;32m---> 68\u001B[0m x_prepared \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_modify_categorical_features\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx_anonymizer_train\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 69\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 70\u001B[0m x_prepared \u001B[38;5;241m=\u001B[39m x_anonymizer_train\n",
- "File \u001B[0;32m~/PycharmProjects/ai-privacy-toolkit-internal/apt/anonymization/anonymizer.py:144\u001B[0m, in \u001B[0;36mAnonymize._modify_categorical_features\u001B[0;34m(self, x)\u001B[0m\n\u001B[1;32m 142\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m_modify_categorical_features\u001B[39m(\u001B[38;5;28mself\u001B[39m, x):\n\u001B[1;32m 143\u001B[0m encoder \u001B[38;5;241m=\u001B[39m OneHotEncoder()\n\u001B[0;32m--> 144\u001B[0m one_hot_encoded \u001B[38;5;241m=\u001B[39m \u001B[43mencoder\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit_transform\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 145\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m one_hot_encoded\n",
- "File \u001B[0;32m~/PycharmProjects/ai-privacy-toolkit-internal/venv/lib/python3.8/site-packages/sklearn/preprocessing/_encoders.py:372\u001B[0m, in \u001B[0;36mOneHotEncoder.fit_transform\u001B[0;34m(self, X, y)\u001B[0m\n\u001B[1;32m 352\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 353\u001B[0m \u001B[38;5;124;03mFit OneHotEncoder to X, then transform X.\u001B[39;00m\n\u001B[1;32m 354\u001B[0m \n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 369\u001B[0m \u001B[38;5;124;03m Transformed input.\u001B[39;00m\n\u001B[1;32m 370\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 371\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_validate_keywords()\n\u001B[0;32m--> 372\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit_transform\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43my\u001B[49m\u001B[43m)\u001B[49m\n",
- "File \u001B[0;32m~/PycharmProjects/ai-privacy-toolkit-internal/venv/lib/python3.8/site-packages/sklearn/base.py:571\u001B[0m, in \u001B[0;36mTransformerMixin.fit_transform\u001B[0;34m(self, X, y, **fit_params)\u001B[0m\n\u001B[1;32m 567\u001B[0m \u001B[38;5;66;03m# non-optimized default implementation; override when a better\u001B[39;00m\n\u001B[1;32m 568\u001B[0m \u001B[38;5;66;03m# method is possible for a given clustering algorithm\u001B[39;00m\n\u001B[1;32m 569\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m y \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m 570\u001B[0m \u001B[38;5;66;03m# fit method of arity 1 (unsupervised transformation)\u001B[39;00m\n\u001B[0;32m--> 571\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mfit_params\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241m.\u001B[39mtransform(X)\n\u001B[1;32m 572\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 573\u001B[0m \u001B[38;5;66;03m# fit method of arity 2 (supervised transformation)\u001B[39;00m\n\u001B[1;32m 574\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mfit(X, y, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mfit_params)\u001B[38;5;241m.\u001B[39mtransform(X)\n",
- "File \u001B[0;32m~/PycharmProjects/ai-privacy-toolkit-internal/venv/lib/python3.8/site-packages/sklearn/preprocessing/_encoders.py:347\u001B[0m, in \u001B[0;36mOneHotEncoder.fit\u001B[0;34m(self, X, y)\u001B[0m\n\u001B[1;32m 330\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 331\u001B[0m \u001B[38;5;124;03mFit OneHotEncoder to X.\u001B[39;00m\n\u001B[1;32m 332\u001B[0m \n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 344\u001B[0m \u001B[38;5;124;03mself\u001B[39;00m\n\u001B[1;32m 345\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 346\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_validate_keywords()\n\u001B[0;32m--> 347\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_fit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mhandle_unknown\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mhandle_unknown\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 348\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdrop_idx_ \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compute_drop_idx()\n\u001B[1;32m 349\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\n",
- "File \u001B[0;32m~/PycharmProjects/ai-privacy-toolkit-internal/venv/lib/python3.8/site-packages/sklearn/preprocessing/_encoders.py:86\u001B[0m, in \u001B[0;36m_BaseEncoder._fit\u001B[0;34m(self, X, handle_unknown)\u001B[0m\n\u001B[1;32m 84\u001B[0m Xi \u001B[38;5;241m=\u001B[39m X_list[i]\n\u001B[1;32m 85\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcategories \u001B[38;5;241m==\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mauto\u001B[39m\u001B[38;5;124m'\u001B[39m:\n\u001B[0;32m---> 86\u001B[0m cats \u001B[38;5;241m=\u001B[39m \u001B[43m_encode\u001B[49m\u001B[43m(\u001B[49m\u001B[43mXi\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 87\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 88\u001B[0m cats \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39marray(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcategories[i], dtype\u001B[38;5;241m=\u001B[39mXi\u001B[38;5;241m.\u001B[39mdtype)\n",
- "File \u001B[0;32m~/PycharmProjects/ai-privacy-toolkit-internal/venv/lib/python3.8/site-packages/sklearn/preprocessing/_label.py:114\u001B[0m, in \u001B[0;36m_encode\u001B[0;34m(values, uniques, encode, check_unknown)\u001B[0m\n\u001B[1;32m 112\u001B[0m res \u001B[38;5;241m=\u001B[39m _encode_python(values, uniques, encode)\n\u001B[1;32m 113\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m:\n\u001B[0;32m--> 114\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124margument must be a string or number\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 115\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m res\n\u001B[1;32m 116\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n",
- "\u001B[0;31mTypeError\u001B[0m: argument must be a string or number"
- ]
- }
- ],
+ "outputs": [],
"source": [
"QI2 = [\"parents\", \"has_nurs\", \"form\", \"children\", \"housing\", \"finance\", \"social\", \"health\"]\n",
"QI2_indexes = [i for i, v in enumerate(features) if v in QI2]\n",
@@ -692,9 +667,18 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 21,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "39"
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"# number of distinct rows in anonymized data\n",
"len(anon3.drop_duplicates())"
@@ -702,9 +686,19 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 22,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Anonymized model accuracy: 0.751929012345679\n",
+ "BB attack accuracy: 1.0\n",
+ "WB attack accuracy: 0.5187150299054601\n"
+ ]
+ }
+ ],
"source": [
"anon3_str = anon3.astype(str)\n",
"anon3_encoded = OneHotEncoder(sparse=False).fit_transform(anon3_str)\n",
@@ -742,9 +736,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 23,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(0.49415432579890883, 0.48976438779451525)\n",
+ "(0.49415432579890883, 0.48976438779451525)\n",
+ "(1.0, 0.019204655674102813)\n",
+ "(1.0, 0.032201745877788554)\n"
+ ]
+ }
+ ],
"source": [
"# black-box regular\n",
"print(calc_precision_recall(inferred_train_bb, x_train_feature))\n",
diff --git a/tests/test_anonymizer.py b/tests/test_anonymizer.py
index d7072e4..83710cd 100644
--- a/tests/test_anonymizer.py
+++ b/tests/test_anonymizer.py
@@ -44,7 +44,7 @@ def test_anonymize_pandas_adult():
QI_indexes = [i for i, v in enumerate(features) if v in QI]
categorical_features_indexes = [i for i, v in enumerate(features) if v in categorical_features]
anonymizer = Anonymize(k, QI_indexes, categorical_features=categorical_features_indexes)
- anon = anonymizer.anonymize(ArrayDataset(x_train, pred))
+ anon = anonymizer.anonymize(ArrayDataset(x_train, pred, features))
assert(anon.loc[:, QI].drop_duplicates().shape[0] < x_train.loc[:, QI].drop_duplicates().shape[0])
assert (anon.loc[:, QI].value_counts().min() >= k)
|