diff --git a/apt/risk/data_assessment/attack_strategy_utils.py b/apt/risk/data_assessment/attack_strategy_utils.py index af12628..7817a66 100644 --- a/apt/risk/data_assessment/attack_strategy_utils.py +++ b/apt/risk/data_assessment/attack_strategy_utils.py @@ -169,26 +169,15 @@ class KNNAttackStrategyUtils(AttackStrategyUtils): differing_columns = [] df1_samples = df1.get_samples() df2_samples = df2.get_samples() - if df1.is_pandas: - for name, _ in df1_samples.items(): - is_categorical = name in categorical_features or is_categorical_dtype(df1_samples.dtypes[name]) - is_numeric = is_numeric_dtype(df1_samples.dtypes[name]) - KNNAttackStrategyUtils._column_statistical_test(df1_samples[name], df2_samples[name], name, - is_categorical, is_numeric, - self.distribution_comparison_numeric_test, - self.distribution_comparison_categorical_test, - self.distribution_comparison_alpha, - differing_columns) - else: - is_numeric = np.issubdtype(df1_samples.dtype, int) or np.issubdtype(df1_samples.dtype, float) + is_numeric = np.issubdtype(df1_samples.dtype, int) or np.issubdtype(df1_samples.dtype, float) - for i, column in enumerate(df1_samples.T): - is_categorical = i in categorical_features - KNNAttackStrategyUtils._column_statistical_test(df1_samples[:, i], df2_samples[:, i], i, - is_categorical, is_numeric, - self.distribution_comparison_numeric_test, - self.distribution_comparison_categorical_test, - self.distribution_comparison_alpha, differing_columns) + for i, column in enumerate(df1_samples.T): + is_categorical = i in categorical_features + KNNAttackStrategyUtils._column_statistical_test(df1_samples[:, i], df2_samples[:, i], i, + is_categorical, is_numeric, + self.distribution_comparison_numeric_test, + self.distribution_comparison_categorical_test, + self.distribution_comparison_alpha, differing_columns) return differing_columns def validate_distributions(self, original_data_members: ArrayDataset, original_data_non_members: ArrayDataset, diff --git a/tests/test_data_assessment_short_test.py b/tests/test_data_assessment_short_test.py index 7416d47..1ca11d5 100644 --- a/tests/test_data_assessment_short_test.py +++ b/tests/test_data_assessment_short_test.py @@ -1,3 +1,4 @@ +import pandas as pd import pytest from apt.anonymization import Anonymize @@ -52,6 +53,8 @@ def test_risk_anonymization(name, data, dataset_type, mgr): categorical_features = [] elif "nursery" in name: preprocessed_x_train, preprocessed_x_test, categorical_features = preprocess_nursery_x_data(x_train, x_test) + preprocessed_x_train = pd.DataFrame(preprocessed_x_train) + preprocessed_x_test = pd.DataFrame(preprocessed_x_test) QI = list(range(15, 20)) anonymizer = Anonymize(ANON_K, QI, train_only_QI=True) else: