diff --git a/apt/anonymization/anonymizer.py b/apt/anonymization/anonymizer.py index 8a4f95d..94eb403 100644 --- a/apt/anonymization/anonymizer.py +++ b/apt/anonymization/anonymizer.py @@ -45,10 +45,10 @@ class Anonymize: :return: An array containing the anonymized training dataset. """ if dataset.features_names is not None: - self._features = dataset.features_names + self.features_names = dataset.features_names # if features is None, use numbers instead of names elif dataset.get_samples().shape[0] != 0: - self._features = [i for i in range(dataset.get_samples().shape[0])] + self.features_names = [i for i in range(dataset.get_samples().shape[0])] else: raise ValueError('No data provided') if not set(self.quasi_identifiers).issubset(set(self.features_names)): @@ -63,7 +63,7 @@ class Anonymize: transformed = self._anonymize(dataset.get_samples().copy(), dataset.get_labels()) if dataset.is_pandas: - return pd.DataFrame(transformed, columns=self._features) + return pd.DataFrame(transformed, columns=self.features_names) else: return transformed diff --git a/tests/test_anonymizer.py b/tests/test_anonymizer.py index 83710cd..464bd20 100644 --- a/tests/test_anonymizer.py +++ b/tests/test_anonymizer.py @@ -41,9 +41,7 @@ def test_anonymize_pandas_adult(): 'native-country'] categorical_features = ['workclass', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country'] - QI_indexes = [i for i, v in enumerate(features) if v in QI] - categorical_features_indexes = [i for i, v in enumerate(features) if v in categorical_features] - anonymizer = Anonymize(k, QI_indexes, categorical_features=categorical_features_indexes) + anonymizer = Anonymize(k, QI, categorical_features=categorical_features) anon = anonymizer.anonymize(ArrayDataset(x_train, pred, features)) assert(anon.loc[:, QI].drop_duplicates().shape[0] < x_train.loc[:, QI].drop_duplicates().shape[0]) @@ -62,9 +60,7 @@ def test_anonymize_pandas_nursery(): k = 100 QI = ["finance", "social", "health"] categorical_features = ["parents", "has_nurs", "form", "housing", "finance", "social", "health", 'children'] - QI_indexes = [i for i, v in enumerate(features) if v in QI] - categorical_features_indexes = [i for i, v in enumerate(features) if v in categorical_features] - anonymizer = Anonymize(k, QI_indexes, categorical_features=categorical_features_indexes) + anonymizer = Anonymize(k, QI, categorical_features=categorical_features) anon = anonymizer.anonymize(ArrayDataset(x_train, pred)) assert(anon.loc[:, QI].drop_duplicates().shape[0] < x_train.loc[:, QI].drop_duplicates().shape[0])