diff --git a/apt/anonymization/anonymizer.py b/apt/anonymization/anonymizer.py index a093402..0f90a2d 100644 --- a/apt/anonymization/anonymizer.py +++ b/apt/anonymization/anonymizer.py @@ -66,6 +66,7 @@ class Anonymize: self.anonymizer = DecisionTreeRegressor(random_state=10, min_samples_split=2, min_samples_leaf=self.k) else: self.anonymizer = DecisionTreeClassifier(random_state=10, min_samples_split=2, min_samples_leaf=self.k) + self.anonymizer.fit(x_prepared, y) cells_by_id = self._calculate_cells(x, x_prepared) return self._anonymize_data_numpy(x, x_prepared, cells_by_id) @@ -80,6 +81,8 @@ class Anonymize: self.anonymizer = DecisionTreeRegressor(random_state=10, min_samples_split=2, min_samples_leaf=self.k) else: self.anonymizer = DecisionTreeClassifier(random_state=10, min_samples_split=2, min_samples_leaf=self.k) + if len(y.shape) > 1: + y = np.argmax(y, axis=1) self.anonymizer.fit(x_prepared, y) cells_by_id = self._calculate_cells(x, x_prepared) return self._anonymize_data_pandas(x, x_prepared, cells_by_id)