Minimization fixes (#12)

* Fixes related to corner cases in calculating generalizations

* Fix print

* Fix corner cases in transform as well

* Improve prints + bug fixes in calculation of feature to remove

* Notebook demonstrating AI minimization
This commit is contained in:
abigailgold 2021-08-17 21:19:48 +03:00 committed by GitHub
parent d2591d7840
commit 43952e2332
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 326 additions and 47 deletions

View file

@ -117,6 +117,10 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self.cells = params['cells']
return self
@property
def generalizations(self):
return self.generalizations_
def fit_transform(self, X=None, y=None):
"""Learns the generalizations based on training data, and applies them to the data.
@ -206,40 +210,45 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
nodes = self._get_nodes_level(0)
self._attach_cells_representatives(X_train, y_train, nodes)
# self.cells_ currently holds the generalization created from the tree leaves
self._calculate_generalizations()
# apply generalizations to test data
generalized = self._generalize(X_test, nodes, self.cells_, self.cells_by_id_)
# check accuracy
accuracy = self.estimator.score(generalized, y_test)
print('Initial accuracy is %f' % accuracy)
print('Initial accuracy of model on generalized data, relative to original model predictions '
'(base generalization derived from tree, before improvements): %f' % accuracy)
# if accuracy above threshold, improve generalization
if accuracy > self.target_accuracy:
print('Improving generalizations')
level = 1
while accuracy > self.target_accuracy:
nodes = self._get_nodes_level(level)
self._calculate_level_cells(level)
self._attach_cells_representatives(X_train, y_train, nodes)
self._calculate_generalizations()
generalized = self._generalize(X_test, nodes, self.cells_,
self.cells_by_id_)
accuracy = self.estimator.score(generalized, y_test)
print('Level: %d, accuracy: %f' % (level, accuracy))
print('Pruned tree to level: %d, new relative accuracy: %f' % (level, accuracy))
level+=1
# if accuracy below threshold, improve accuracy by removing features from generalization
if accuracy < self.target_accuracy:
print('Improving accuracy')
while accuracy < self.target_accuracy:
self._calculate_generalizations()
removed_feature = self._remove_feature_from_generalization(X_test,
nodes, y_test,
feature_data)
if not removed_feature:
feature_data, accuracy)
if removed_feature is None:
break
generalized = self._generalize(X_test, nodes, self.cells_,
self.cells_by_id_)
self._calculate_generalizations()
generalized = self._generalize(X_test, nodes, self.cells_, self.cells_by_id_)
accuracy = self.estimator.score(generalized, y_test)
print('Removed feature: %s, accuracy: %f' % (removed_feature, accuracy))
print('Removed feature: %s, new relative accuracy: %f' % (removed_feature, accuracy))
# self.cells_ currently holds the chosen generalization based on target accuracy
@ -304,12 +313,12 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# replace the values in the representative columns with the representative
# values (leaves others untouched)
if not representatives.columns.empty:
if indexes and not representatives.columns.empty:
if len(indexes) > 1:
replace = pd.concat([representatives.loc[i].to_frame().T]*len(indexes)).reset_index(drop=True)
replace.index = indexes
else:
replace = representatives.loc[i].to_frame().T
replace = representatives.loc[i].to_frame().T.reset_index(drop=True)
replace.index = indexes
generalized.loc[indexes, representatives.columns] = replace
return generalized.to_numpy()
@ -409,30 +418,31 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
new_cells = []
new_cells_by_id = {}
nodes = self._get_nodes_level(level)
for node in nodes:
if self.dt_.tree_.feature[node] == -2: # leaf node
new_cell = self.cells_by_id_[node]
else:
left_child = self.dt_.tree_.children_left[node]
right_child = self.dt_.tree_.children_right[node]
left_cell = self.cells_by_id_[left_child]
right_cell = self.cells_by_id_[right_child]
new_cell = {'id': int(node), 'ranges': {}, 'categories': {},
'label': None, 'representative': None}
for feature in left_cell['ranges'].keys():
new_cell['ranges'][feature] = {}
new_cell['ranges'][feature]['start'] = left_cell['ranges'][feature]['start']
new_cell['ranges'][feature]['end'] = right_cell['ranges'][feature]['start']
for feature in left_cell['categories'].keys():
new_cell['categories'][feature] = \
list(set(left_cell['categories'][feature]) |
set(right_cell['categories'][feature]))
self._calculate_level_cell_label(left_cell, right_cell, new_cell)
new_cells.append(new_cell)
new_cells_by_id[new_cell['id']] = new_cell
self.cells_ = new_cells
self.cells_by_id_ = new_cells_by_id
# else: nothing to do, stay with previous cells
if nodes:
for node in nodes:
if self.dt_.tree_.feature[node] == -2: # leaf node
new_cell = self.cells_by_id_[node]
else:
left_child = self.dt_.tree_.children_left[node]
right_child = self.dt_.tree_.children_right[node]
left_cell = self.cells_by_id_[left_child]
right_cell = self.cells_by_id_[right_child]
new_cell = {'id': int(node), 'ranges': {}, 'categories': {},
'label': None, 'representative': None}
for feature in left_cell['ranges'].keys():
new_cell['ranges'][feature] = {}
new_cell['ranges'][feature]['start'] = left_cell['ranges'][feature]['start']
new_cell['ranges'][feature]['end'] = right_cell['ranges'][feature]['start']
for feature in left_cell['categories'].keys():
new_cell['categories'][feature] = \
list(set(left_cell['categories'][feature]) |
set(right_cell['categories'][feature]))
self._calculate_level_cell_label(left_cell, right_cell, new_cell)
new_cells.append(new_cell)
new_cells_by_id[new_cell['id']] = new_cell
self.cells_ = new_cells
self.cells_by_id_ = new_cells_by_id
# else: nothing to do, stay with previous cells
def _calculate_level_cell_label(self, left_cell, right_cell, new_cell):
new_cell['hist'] = [x + y for x, y in zip(left_cell['hist'], right_cell['hist'])]
@ -445,6 +455,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
stack = [(0, -1)] # seed is the root node id and its parent depth
while len(stack) > 0:
node_id, parent_depth = stack.pop()
# depth = distance from root
node_depth[node_id] = parent_depth + 1
if self.dt_.tree_.children_left[node_id] != self.dt_.tree_.children_right[node_id]:
@ -453,10 +464,14 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
else:
is_leaves[node_id] = True
# depth of entire tree
max_depth = max(node_depth)
# depth of current level
depth = max_depth - level
# requested level is above the root (level exceeds the tree depth): no nodes to return
if depth < 0:
return None
# return all nodes at the target depth (max_depth - level), plus any leaves shallower than that depth
return [i for i, x in enumerate(node_depth) if x == depth or (x < depth and is_leaves[i])]
def _attach_cells_representatives(self, samples, labels, level_nodes):
@ -518,12 +533,12 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
indexes = [j for j in range(len(mapping_to_cells)) if mapping_to_cells[j]['id'] == cells[i]['id']]
# replaces the values in the representative columns with the representative values
# (leaves others untouched)
if not representatives.columns.empty:
if indexes and not representatives.columns.empty:
if len(indexes) > 1:
replace = pd.concat([representatives.loc[i].to_frame().T]*len(indexes)).reset_index(drop=True)
replace.index = indexes
else:
replace = representatives.loc[i].to_frame().T
replace = representatives.loc[i].to_frame().T.reset_index(drop=True)
replace.index = indexes
generalized.loc[indexes, representatives.columns] = replace
return generalized.to_numpy()
@ -539,14 +554,16 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
node_ids = self._find_sample_nodes(samples, nodes)
return [cells_by_id[nodeId] for nodeId in node_ids]
def _remove_feature_from_generalization(self, samples, nodes, labels, feature_data):
feature = self._get_feature_to_remove(samples, nodes, labels, feature_data)
if not feature:
def _remove_feature_from_generalization(self, samples, nodes, labels, feature_data, current_accuracy):
feature = self._get_feature_to_remove(samples, nodes, labels, feature_data, current_accuracy)
if feature is None:
return None
GeneralizeToRepresentative._remove_feature_from_cells(self.cells_, self.cells_by_id_, feature)
# del self.generalizations_['ranges'][feature]
# self.generalizations_['untouched'].append(feature)
return feature
def _get_feature_to_remove(self, samples, nodes, labels, feature_data):
def _get_feature_to_remove(self, samples, nodes, labels, feature_data, current_accuracy):
# We want to remove features with low iLoss (NCP) and high accuracy gain
# (after removing them)
ranges = self.generalizations_['ranges']
@ -567,13 +584,17 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
cells_by_id = copy.deepcopy(self.cells_by_id_)
GeneralizeToRepresentative._remove_feature_from_cells(new_cells, cells_by_id, feature)
generalized = self._generalize(samples, nodes, new_cells, cells_by_id)
accuracy = self.estimator.score(generalized, labels)
feature_ncp = feature_ncp / accuracy
accuracy_gain = self.estimator.score(generalized, labels) - current_accuracy
if accuracy_gain < 0:
accuracy_gain = 0
if accuracy_gain != 0:
feature_ncp = feature_ncp / accuracy_gain
if feature_ncp < range_min:
range_min = feature_ncp
remove_feature = feature
print('feature to remove: ' + (remove_feature if remove_feature else ''))
print('feature to remove: ' + (str(remove_feature) if remove_feature is not None else 'none'))
return remove_feature
def _calculate_generalizations(self):
@ -660,5 +681,3 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
del cell['categories'][feature]
cell['untouched'].append(feature)
cells_by_id[cell['id']] = cell.copy()