Fix bug in pruning loop + fix test

This commit is contained in:
abigailt 2022-05-19 17:49:59 +03:00 committed by abigailgold
parent 186f11eaaf
commit 7055d5ecf6
2 changed files with 29 additions and 25 deletions

View file

@@ -324,12 +324,17 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
print('Improving generalizations') print('Improving generalizations')
level = 1 level = 1
while accuracy > self.target_accuracy: while accuracy > self.target_accuracy:
try:
cells_previous_iter = self.cells cells_previous_iter = self.cells
generalization_prev_iter = self._generalizations generalization_prev_iter = self._generalizations
cells_by_id_prev = self._cells_by_id cells_by_id_prev = self._cells_by_id
nodes = self._get_nodes_level(level) nodes = self._get_nodes_level(level)
try:
self._calculate_level_cells(level) self._calculate_level_cells(level)
except TypeError as e:
print(e)
break
self._attach_cells_representatives(x_prepared, used_X_train, y_train, nodes) self._attach_cells_representatives(x_prepared, used_X_train, y_train, nodes)
self._calculate_generalizations() self._calculate_generalizations()
@@ -345,9 +350,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
else: else:
print('Pruned tree to level: %d, new relative accuracy: %f' % (level, accuracy)) print('Pruned tree to level: %d, new relative accuracy: %f' % (level, accuracy))
level += 1 level += 1
except Exception as e:
print(e)
break
# if accuracy below threshold, improve accuracy by removing features from generalization # if accuracy below threshold, improve accuracy by removing features from generalization
elif accuracy < self.target_accuracy: elif accuracy < self.target_accuracy:
@@ -569,7 +572,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
features = self._categorical_data.columns features = self._categorical_data.columns
for cell in self.cells: for cell in self.cells:
new_cell = {'id': cell['id'], 'label': cell['label'], 'ranges': {}, 'categories': {}, 'hist': cell['hist'], new_cell = {'id': cell['id'], 'label': cell['label'], 'ranges': {}, 'categories': {}, 'hist': cell['hist'],
'representative': None} 'untouched': [], 'representative': None}
for feature in features: for feature in features:
if feature in self._one_hot_vector_features_to_features.keys(): if feature in self._one_hot_vector_features_to_features.keys():
# feature is categorical and should be mapped # feature is categorical and should be mapped

View file

@@ -1,6 +1,8 @@
import pytest import pytest
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from numpy.testing import assert_almost_equal
from sklearn.compose import ColumnTransformer from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_boston, load_diabetes from sklearn.datasets import load_boston, load_diabetes
@@ -912,13 +914,12 @@ def test_blackbox_model():
gen.fit(dataset=train_dataset) gen.fit(dataset=train_dataset)
transformed = gen.transform(dataset=ad) transformed = gen.transform(dataset=ad)
gener = gen.generalizations gener = gen.generalizations
expected_generalizations = {'ranges': {'0': [], '1': [], '2': [4.849999904632568, 5.049999952316284], expected_generalizations = {'ranges': {'0': [], '1': [], '2': [4.849999904632568], '3': [0.7000000029802322]},
'3': [0.7000000029802322, 1.600000023841858]},
'categories': {}, 'categories': {},
'untouched': []} 'untouched': []}
for key in expected_generalizations['ranges']: for key in expected_generalizations['ranges']:
assert (set(expected_generalizations['ranges'][key]) == set(gener['ranges'][key])) assert_almost_equal(expected_generalizations['ranges'][key], gener['ranges'][key])
for key in expected_generalizations['categories']: for key in expected_generalizations['categories']:
assert (set([frozenset(sl) for sl in expected_generalizations['categories'][key]]) == assert (set([frozenset(sl) for sl in expected_generalizations['categories'][key]]) ==
set([frozenset(sl) for sl in gener['categories'][key]])) set([frozenset(sl) for sl in gener['categories'][key]]))