Fix bug in pruning loop + fix test

This commit is contained in:
abigailt 2022-05-19 17:49:59 +03:00 committed by abigailgold
parent 186f11eaaf
commit 7055d5ecf6
2 changed files with 29 additions and 25 deletions

View file

@ -324,31 +324,34 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
print('Improving generalizations')
level = 1
while accuracy > self.target_accuracy:
try:
cells_previous_iter = self.cells
generalization_prev_iter = self._generalizations
cells_by_id_prev = self._cells_by_id
nodes = self._get_nodes_level(level)
self._calculate_level_cells(level)
self._attach_cells_representatives(x_prepared, used_X_train, y_train, nodes)
cells_previous_iter = self.cells
generalization_prev_iter = self._generalizations
cells_by_id_prev = self._cells_by_id
nodes = self._get_nodes_level(level)
self._calculate_generalizations()
generalized = self._generalize(X_test, x_prepared_test, nodes, self.cells,
self._cells_by_id)
accuracy = self.estimator.score(ArrayDataset(preprocessor.transform(generalized), y_test))
# if accuracy dropped below the target threshold, roll back to the previous iteration's generalizations
if accuracy < self.target_accuracy:
self.cells = cells_previous_iter
self._generalizations = generalization_prev_iter
self._cells_by_id = cells_by_id_prev
break
else:
print('Pruned tree to level: %d, new relative accuracy: %f' % (level, accuracy))
level += 1
except Exception as e:
try:
self._calculate_level_cells(level)
except TypeError as e:
print(e)
break
self._attach_cells_representatives(x_prepared, used_X_train, y_train, nodes)
self._calculate_generalizations()
generalized = self._generalize(X_test, x_prepared_test, nodes, self.cells,
self._cells_by_id)
accuracy = self.estimator.score(ArrayDataset(preprocessor.transform(generalized), y_test))
# if accuracy dropped below the target threshold, roll back to the previous iteration's generalizations
if accuracy < self.target_accuracy:
self.cells = cells_previous_iter
self._generalizations = generalization_prev_iter
self._cells_by_id = cells_by_id_prev
break
else:
print('Pruned tree to level: %d, new relative accuracy: %f' % (level, accuracy))
level += 1
# if accuracy below threshold, improve accuracy by removing features from generalization
elif accuracy < self.target_accuracy:
print('Improving accuracy')
@ -569,7 +572,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
features = self._categorical_data.columns
for cell in self.cells:
new_cell = {'id': cell['id'], 'label': cell['label'], 'ranges': {}, 'categories': {}, 'hist': cell['hist'],
'representative': None}
'untouched': [], 'representative': None}
for feature in features:
if feature in self._one_hot_vector_features_to_features.keys():
# feature is categorical and should be mapped