Extract common code to methods

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2023-10-18 12:22:26 +03:00
parent 570c6f8966
commit ea8564bc4b

View file

@ -315,17 +315,10 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self._attach_cells_representatives(x_prepared, used_x_train, y_train, nodes) self._attach_cells_representatives(x_prepared, used_x_train, y_train, nodes)
# self._cells currently holds the generalization created from the tree leaves # self._cells currently holds the generalization created from the tree leaves
self._calculate_generalizations(x_test) generalized = self._generalize(x_test, x_prepared_test, nodes)
if self.generalize_using_transform:
generalized = self._generalize_from_tree(x_test, x_prepared_test, nodes, self.cells, self._cells_by_id)
else:
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
# check accuracy # check accuracy
if self.encoder: accuracy = self._calculate_accuracy(generalized, y_test, self.estimator, self.encoder)
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
else:
accuracy = self.estimator.score(ArrayDataset(generalized, y_test))
print('Initial accuracy of model on generalized data, relative to original model predictions ' print('Initial accuracy of model on generalized data, relative to original model predictions '
'(base generalization derived from tree, before improvements): %f' % accuracy) '(base generalization derived from tree, before improvements): %f' % accuracy)
@ -348,17 +341,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self._attach_cells_representatives(x_prepared, used_x_train, y_train, nodes) self._attach_cells_representatives(x_prepared, used_x_train, y_train, nodes)
self._calculate_generalizations(x_test) generalized = self._generalize(x_test, x_prepared_test, nodes)
if self.generalize_using_transform: accuracy = self._calculate_accuracy(generalized, y_test, self.estimator, self.encoder)
generalized = self._generalize_from_tree(x_test, x_prepared_test, nodes, self.cells,
self._cells_by_id)
else:
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
if self.encoder:
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
else:
accuracy = self.estimator.score(ArrayDataset(generalized, y_test))
# if accuracy passed threshold roll back to previous iteration generalizations # if accuracy passed threshold roll back to previous iteration generalizations
if accuracy < self.target_accuracy: if accuracy < self.target_accuracy:
self.cells = cells_previous_iter self.cells = cells_previous_iter
@ -381,17 +365,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
if removed_feature is None: if removed_feature is None:
break break
self._calculate_generalizations(x_test) generalized = self._generalize(x_test, x_prepared_test, nodes)
if self.generalize_using_transform: accuracy = self._calculate_accuracy(generalized, y_test, self.estimator, self.encoder)
generalized = self._generalize_from_tree(x_test, x_prepared_test, nodes, self.cells,
self._cells_by_id)
else:
generalized = self._generalize_from_generalizations(x_test, self.generalizations)
if self.encoder:
accuracy = self.estimator.score(ArrayDataset(self.encoder.transform(generalized), y_test))
else:
accuracy = self.estimator.score(ArrayDataset(generalized, y_test))
print('Removed feature: %s, new relative accuracy: %f' % (removed_feature, accuracy)) print('Removed feature: %s, new relative accuracy: %f' % (removed_feature, accuracy))
# self._cells currently holds the chosen generalization based on target accuracy # self._cells currently holds the chosen generalization based on target accuracy
@ -918,6 +893,15 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
return original_data_generalized return original_data_generalized
def _generalize(self, data, data_prepared, nodes):
self._calculate_generalizations(data)
if self.generalize_using_transform:
generalized = self._generalize_from_tree(data, data_prepared, nodes, self.cells,
self._cells_by_id)
else:
generalized = self._generalize_from_generalizations(data, self.generalizations)
return generalized
@staticmethod @staticmethod
def _map_to_ranges_categories(samples, ranges, categories): def _map_to_ranges_categories(samples, ranges, categories):
all_sample_indexes = [] all_sample_indexes = []
@ -987,21 +971,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
feature_data[feature], feature_data[feature],
total) total)
if feature_ncp > 0: if feature_ncp > 0:
# divide by accuracy gain feature_ncp = self._normalize_ncp_by_accuracy_gain(original_data, prepared_data, nodes, feature,
new_cells = copy.deepcopy(self.cells) feature_ncp, labels, current_accuracy)
cells_by_id = copy.deepcopy(self._cells_by_id)
GeneralizeToRepresentative._remove_feature_from_cells(new_cells, cells_by_id, feature)
generalized = self._generalize_from_tree(original_data, prepared_data, nodes, new_cells,
cells_by_id)
if self.encoder:
accuracy_gain = self.estimator.score(ArrayDataset(self.encoder.transform(generalized),
labels)) - current_accuracy
else:
accuracy_gain = self.estimator.score(ArrayDataset(generalized, labels)) - current_accuracy
if accuracy_gain < 0:
accuracy_gain = 0
if accuracy_gain != 0:
feature_ncp = feature_ncp / accuracy_gain
if feature_ncp < range_min: if feature_ncp < range_min:
range_min = feature_ncp range_min = feature_ncp
@ -1017,22 +988,9 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
feature_data[feature], feature_data[feature],
total) total)
if feature_ncp > 0: if feature_ncp > 0:
# divide by accuracy loss feature_ncp = self._normalize_ncp_by_accuracy_gain(original_data, prepared_data, nodes, feature,
new_cells = copy.deepcopy(self.cells) feature_ncp, labels, current_accuracy)
cells_by_id = copy.deepcopy(self._cells_by_id)
GeneralizeToRepresentative._remove_feature_from_cells(new_cells, cells_by_id, feature)
generalized = self._generalize_from_tree(original_data, prepared_data, nodes, new_cells,
cells_by_id)
if self.encoder:
accuracy_gain = self.estimator.score(ArrayDataset(self.encoder.transform(generalized),
labels)) - current_accuracy
else:
accuracy_gain = self.estimator.score(ArrayDataset(generalized, labels)) - current_accuracy
if accuracy_gain < 0:
accuracy_gain = 0
if accuracy_gain != 0:
feature_ncp = feature_ncp / accuracy_gain
if feature_ncp < range_min: if feature_ncp < range_min:
range_min = feature_ncp range_min = feature_ncp
remove_feature = feature remove_feature = feature
@ -1062,6 +1020,21 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
feature_ncp += cell_ncp feature_ncp += cell_ncp
return feature_ncp return feature_ncp
def _normalize_ncp_by_accuracy_gain(self, original_data, prepared_data, nodes, feature, feature_ncp, labels,
current_accuracy):
new_cells = copy.deepcopy(self.cells)
cells_by_id = copy.deepcopy(self._cells_by_id)
GeneralizeToRepresentative._remove_feature_from_cells(new_cells, cells_by_id, feature)
generalized = self._generalize_from_tree(original_data, prepared_data, nodes, new_cells,
cells_by_id)
accuracy = self._calculate_accuracy(generalized, labels, self.estimator, self.encoder)
accuracy_gain = accuracy - current_accuracy
if accuracy_gain < 0:
accuracy_gain = 0
if accuracy_gain != 0:
feature_ncp = feature_ncp / accuracy_gain
return feature_ncp
def _calculate_generalizations(self, samples: Optional[pd.DataFrame] = None): def _calculate_generalizations(self, samples: Optional[pd.DataFrame] = None):
ranges, range_representatives = self._calculate_ranges(self.cells) ranges, range_representatives = self._calculate_ranges(self.cells)
categories, category_representatives = self._calculate_categories(self.cells) categories, category_representatives = self._calculate_categories(self.cells)
@ -1281,3 +1254,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
for feature in to_remove: for feature in to_remove:
del generalizations['categories'][feature] del generalizations['categories'][feature]
@staticmethod
def _calculate_accuracy(generalized, y_test, estimator, encoder):
generalized_data = encoder.transform(generalized) if encoder else generalized
return estimator.score(ArrayDataset(generalized_data, y_test))